diff --git a/docs/diffusers/installation.md b/docs/diffusers/installation.md index d0f9ad2fa3..c32c624947 100644 --- a/docs/diffusers/installation.md +++ b/docs/diffusers/installation.md @@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License. # Installation -🤗 Diffusers is tested on Python 3.8+, MindSpore 2.2.10+. Follow the installation instructions below for the deep learning library you are using: +🤗 Diffusers is tested on Python 3.8+, MindSpore 2.3+. Follow the installation instructions below for the deep learning library you are using: - [MindSpore](https://www.mindspore.cn/install) installation instructions diff --git a/docs/diffusers/limitations.md b/docs/diffusers/limitations.md index 30b6153c71..6a3a7fcedf 100644 --- a/docs/diffusers/limitations.md +++ b/docs/diffusers/limitations.md @@ -18,6 +18,15 @@ Due to differences in framework, some APIs & models will not be identical to [hu Unlike the output `posterior = DiagonalGaussianDistribution(latent)`, which can do sampling by `posterior.sample()`. We can only output the `latent` and then do sampling through `AutoencoderKL.diag_gauss_dist.sample(latent)`. +### `self.config` in `construct()` + +For many models, parameters used in initialization will be registered in `self.config`. They are often accessed during the `construct` like using `if self.config.xxx == xxx` to determine execution paths in origin 🤗diffusers. However getting attributes like this is not supported by static graph syntax of MindSpore. Two feasible replacement options are + +- set new attributes in initialization for `self` like `self.xxx = self.config.xxx`, then use `self.xxx` in `construct` instead. +- use `self.config["xxx"]` as `self.config` is an `OrderedDict` and getting items like this is supported in static graph mode. + +When `self.config.xxx` changed, we change `self.xxx` and `self.config["xxx"]` both. + ## Models The table below represents the current support in mindone/diffusers for each of those modules, whether they have support in Pynative fp16 mode, Graph fp16 mode, Pynative fp32 mode or Graph fp32 mode. @@ -58,7 +67,6 @@ The table below represents the current support in mindone/diffusers for each of The table below represents the current support in mindone/diffusers for each of those pipelines in **MindSpore 2.3.0**, whether they have support in Pynative fp16 mode, Graph fp16 mode, Pynative fp32 mode or Graph fp32 mode. -> Hint: Due to the precision issue with GroupNorm affecting almost all pipelines under FP16, leading to inference > precision issues of pipelines, the experiments in the table below default to upcasting GroupNorm to FP32 to avoid > this issue. diff --git a/mindone/diffusers/__init__.py b/mindone/diffusers/__init__.py index 772c1f3611..25a995e641 100644 --- a/mindone/diffusers/__init__.py +++ b/mindone/diffusers/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.27.1" +__version__ = "0.29.2" from typing import TYPE_CHECKING @@ -13,6 +13,7 @@ _import_structure = { "configuration_utils": ["ConfigMixin"], + "loaders": ["FromOriginalModelMixin"], "models": [ "AsymmetricAutoencoderKL", "AutoencoderKL", @@ -20,21 +21,28 @@ "AutoencoderTiny", "ConsistencyDecoderVAE", "ControlNetModel", + "ControlNetXSAdapter", + "DiTTransformer2DModel", + "HunyuanDiT2DModel", "I2VGenXLUNet", "Kandinsky3UNet", "ModelMixin", "MotionAdapter", "MultiAdapter", + "PixArtTransformer2DModel", "PriorTransformer", + "SD3ControlNetModel", + "SD3MultiControlNetModel", + "SD3Transformer2DModel", "T2IAdapter", "T5FilmDecoder", "Transformer2DModel", - "SD3Transformer2DModel", "StableCascadeUNet", "UNet1DModel", "UNet2DConditionModel", "UNet2DModel", "UNet3DConditionModel", + "UNetControlNetXSModel", "UNetMotionModel", "UNetSpatioTemporalConditionModel", "UVit2DModel", @@ -51,6 +59,7 @@ ], "pipelines": [ "AnimateDiffPipeline", + "AnimateDiffSDXLPipeline", "AnimateDiffVideoToVideoPipeline", "BlipDiffusionControlNetPipeline", "BlipDiffusionPipeline", @@ -60,6 +69,7 @@ "DDPMPipeline", "DiffusionPipeline", "DiTPipeline", + "HunyuanDiTPipeline", "I2VGenXLPipeline", "IFImg2ImgPipeline", "IFImg2ImgSuperResolutionPipeline", @@ -90,17 +100,23 @@ "LatentConsistencyModelPipeline", "LDMSuperResolutionPipeline", "LDMTextToImagePipeline", + "MarigoldDepthPipeline", + "MarigoldNormalsPipeline", "PixArtAlphaPipeline", + "PixArtSigmaPipeline", "ShapEImg2ImgPipeline", "ShapEPipeline", "StableCascadeCombinedPipeline", "StableCascadeDecoderPipeline", "StableCascadePriorPipeline", + "StableDiffusion3ControlNetPipeline", + "StableDiffusion3Img2ImgPipeline", "StableDiffusion3Pipeline", "StableDiffusionAdapterPipeline", "StableDiffusionControlNetImg2ImgPipeline", "StableDiffusionControlNetInpaintPipeline", "StableDiffusionControlNetPipeline", + "StableDiffusionControlNetXSPipeline", "StableDiffusionDepth2ImgPipeline", "StableDiffusionDiffEditPipeline", "StableDiffusionGLIGENPipeline", @@ -116,6 +132,7 @@ "StableDiffusionXLControlNetImg2ImgPipeline", "StableDiffusionXLControlNetInpaintPipeline", "StableDiffusionXLControlNetPipeline", + "StableDiffusionXLControlNetXSPipeline", "StableDiffusionXLImg2ImgPipeline", "StableDiffusionXLInpaintPipeline", "StableDiffusionXLInstructPix2PixPipeline", @@ -172,12 +189,18 @@ AutoencoderTiny, ConsistencyDecoderVAE, ControlNetModel, + ControlNetXSAdapter, + DiTTransformer2DModel, + HunyuanDiT2DModel, I2VGenXLUNet, Kandinsky3UNet, ModelMixin, MotionAdapter, MultiAdapter, + PixArtTransformer2DModel, PriorTransformer, + SD3ControlNetModel, + SD3MultiControlNetModel, SD3Transformer2DModel, StableCascadeUNet, T2IAdapter, @@ -187,6 +210,7 @@ UNet2DConditionModel, UNet2DModel, UNet3DConditionModel, + UNetControlNetXSModel, UNetMotionModel, UNetSpatioTemporalConditionModel, UVit2DModel, @@ -203,6 +227,7 @@ ) from .pipelines import ( AnimateDiffPipeline, + AnimateDiffSDXLPipeline, AnimateDiffVideoToVideoPipeline, BlipDiffusionControlNetPipeline, BlipDiffusionPipeline, @@ -211,6 +236,7 @@ DDPMPipeline, DiffusionPipeline, DiTPipeline, + HunyuanDiTPipeline, I2VGenXLPipeline, IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, @@ -241,17 +267,23 @@ LatentConsistencyModelPipeline, LDMSuperResolutionPipeline, LDMTextToImagePipeline, + MarigoldDepthPipeline, + MarigoldNormalsPipeline, PixArtAlphaPipeline, + PixArtSigmaPipeline, ShapEImg2ImgPipeline, ShapEPipeline, StableCascadeCombinedPipeline, StableCascadeDecoderPipeline, StableCascadePriorPipeline, + StableDiffusion3ControlNetPipeline, + StableDiffusion3Img2ImgPipeline, StableDiffusion3Pipeline, StableDiffusionAdapterPipeline, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetInpaintPipeline, StableDiffusionControlNetPipeline, + StableDiffusionControlNetXSPipeline, StableDiffusionDepth2ImgPipeline, StableDiffusionDiffEditPipeline, StableDiffusionGLIGENPipeline, @@ -267,6 +299,7 @@ StableDiffusionXLControlNetImg2ImgPipeline, StableDiffusionXLControlNetInpaintPipeline, StableDiffusionXLControlNetPipeline, + StableDiffusionXLControlNetXSPipeline, StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLInstructPix2PixPipeline, diff --git a/mindone/diffusers/callbacks.py b/mindone/diffusers/callbacks.py new file mode 100644 index 0000000000..38542407e3 --- /dev/null +++ b/mindone/diffusers/callbacks.py @@ -0,0 +1,156 @@ +from typing import Any, Dict, List + +from .configuration_utils import ConfigMixin, register_to_config +from .utils import CONFIG_NAME + + +class PipelineCallback(ConfigMixin): + """ + Base class for all the official callbacks used in a pipeline. This class provides a structure for implementing + custom callbacks and ensures that all callbacks have a consistent interface. + + Please implement the following: + `tensor_inputs`: This should return a list of tensor inputs specific to your callback. You will only be able to + include + variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. + `callback_fn`: This method defines the core functionality of your callback. + """ + + config_name = CONFIG_NAME + + @register_to_config + def __init__(self, cutoff_step_ratio=1.0, cutoff_step_index=None): + super().__init__() + + if (cutoff_step_ratio is None and cutoff_step_index is None) or ( + cutoff_step_ratio is not None and cutoff_step_index is not None + ): + raise ValueError("Either cutoff_step_ratio or cutoff_step_index should be provided, not both or none.") + + if cutoff_step_ratio is not None and ( + not isinstance(cutoff_step_ratio, float) or not (0.0 <= cutoff_step_ratio <= 1.0) + ): + raise ValueError("cutoff_step_ratio must be a float between 0.0 and 1.0.") + + @property + def tensor_inputs(self) -> List[str]: + raise NotImplementedError(f"You need to set the attribute `tensor_inputs` for {self.__class__}") + + def callback_fn(self, pipeline, step_index, timesteps, callback_kwargs) -> Dict[str, Any]: + raise NotImplementedError(f"You need to implement the method `callback_fn` for {self.__class__}") + + def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]: + return self.callback_fn(pipeline, step_index, timestep, callback_kwargs) + + +class MultiPipelineCallbacks: + """ + This class is designed to handle multiple pipeline callbacks. It accepts a list of PipelineCallback objects and + provides a unified interface for calling all of them. + """ + + def __init__(self, callbacks: List[PipelineCallback]): + self.callbacks = callbacks + + @property + def tensor_inputs(self) -> List[str]: + return [input for callback in self.callbacks for input in callback.tensor_inputs] + + def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]: + """ + Calls all the callbacks in order with the given arguments and returns the final callback_kwargs. + """ + for callback in self.callbacks: + callback_kwargs = callback(pipeline, step_index, timestep, callback_kwargs) + + return callback_kwargs + + +class SDCFGCutoffCallback(PipelineCallback): + """ + Callback function for Stable Diffusion Pipelines. After certain number of steps (set by `cutoff_step_ratio` or + `cutoff_step_index`), this callback will disable the CFG. + + Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step. + """ + + tensor_inputs = ["prompt_embeds"] + + def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]: + cutoff_step_ratio = self.config.cutoff_step_ratio + cutoff_step_index = self.config.cutoff_step_index + + # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio + cutoff_step = ( + cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio) + ) + + if step_index == cutoff_step: + prompt_embeds = callback_kwargs[self.tensor_inputs[0]] + prompt_embeds = prompt_embeds[-1:] # "-1" denotes the embeddings for conditional text tokens. + + pipeline._guidance_scale = 0.0 + + callback_kwargs[self.tensor_inputs[0]] = prompt_embeds + return callback_kwargs + + +class SDXLCFGCutoffCallback(PipelineCallback): + """ + Callback function for Stable Diffusion XL Pipelines. After certain number of steps (set by `cutoff_step_ratio` or + `cutoff_step_index`), this callback will disable the CFG. + + Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step. + """ + + tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids"] + + def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]: + cutoff_step_ratio = self.config.cutoff_step_ratio + cutoff_step_index = self.config.cutoff_step_index + + # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio + cutoff_step = ( + cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio) + ) + + if step_index == cutoff_step: + prompt_embeds = callback_kwargs[self.tensor_inputs[0]] + prompt_embeds = prompt_embeds[-1:] # "-1" denotes the embeddings for conditional text tokens. + + add_text_embeds = callback_kwargs[self.tensor_inputs[1]] + add_text_embeds = add_text_embeds[-1:] # "-1" denotes the embeddings for conditional pooled text tokens + + add_time_ids = callback_kwargs[self.tensor_inputs[2]] + add_time_ids = add_time_ids[-1:] # "-1" denotes the embeddings for conditional added time vector + + pipeline._guidance_scale = 0.0 + + callback_kwargs[self.tensor_inputs[0]] = prompt_embeds + callback_kwargs[self.tensor_inputs[1]] = add_text_embeds + callback_kwargs[self.tensor_inputs[2]] = add_time_ids + return callback_kwargs + + +class IPAdapterScaleCutoffCallback(PipelineCallback): + """ + Callback function for any pipeline that inherits `IPAdapterMixin`. After certain number of steps (set by + `cutoff_step_ratio` or `cutoff_step_index`), this callback will set the IP Adapter scale to `0.0`. + + Note: This callback mutates the IP Adapter attention processors by setting the scale to 0.0 after the cutoff step. + """ + + tensor_inputs = [] + + def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]: + cutoff_step_ratio = self.config.cutoff_step_ratio + cutoff_step_index = self.config.cutoff_step_index + + # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio + cutoff_step = ( + cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio) + ) + + if step_index == cutoff_step: + pipeline.set_ip_adapter_scale(0.0) + return callback_kwargs diff --git a/mindone/diffusers/configuration_utils.py b/mindone/diffusers/configuration_utils.py index 018ddf7a79..a95d8b3dc4 100644 --- a/mindone/diffusers/configuration_utils.py +++ b/mindone/diffusers/configuration_utils.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" ConfigMixin base class and utilities.""" +"""ConfigMixin base class and utilities.""" import functools import inspect import json @@ -306,9 +306,9 @@ def load_config( force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to resume downloading the model weights and configuration files. If set to `False`, any - incompletely downloaded files are deleted. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1 + of Diffusers. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. @@ -336,8 +336,10 @@ def load_config( """ cache_dir = kwargs.pop("cache_dir", None) + local_dir = kwargs.pop("local_dir", None) + local_dir_use_symlinks = kwargs.pop("local_dir_use_symlinks", "auto") force_download = kwargs.pop("force_download", False) - resume_download = kwargs.pop("resume_download", False) + resume_download = kwargs.pop("resume_download", None) proxies = kwargs.pop("proxies", None) token = kwargs.pop("token", None) local_files_only = kwargs.pop("local_files_only", False) @@ -360,13 +362,13 @@ def load_config( if os.path.isfile(pretrained_model_name_or_path): config_file = pretrained_model_name_or_path elif os.path.isdir(pretrained_model_name_or_path): - if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)): - # Load from a PyTorch checkpoint - config_file = os.path.join(pretrained_model_name_or_path, cls.config_name) - elif subfolder is not None and os.path.isfile( + if subfolder is not None and os.path.isfile( os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name) ): config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name) + elif os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)): + # Load from a PyTorch checkpoint + config_file = os.path.join(pretrained_model_name_or_path, cls.config_name) else: raise EnvironmentError( f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}." @@ -386,6 +388,8 @@ def load_config( user_agent=user_agent, subfolder=subfolder, revision=revision, + local_dir=local_dir, + local_dir_use_symlinks=local_dir_use_symlinks, ) except RepositoryNotFoundError: raise EnvironmentError( @@ -446,8 +450,8 @@ def load_config( return outputs @staticmethod - def _get_init_keys(cls): - return set(dict(inspect.signature(cls.__init__).parameters).keys()) + def _get_init_keys(input_class): + return set(dict(inspect.signature(input_class.__init__).parameters).keys()) @classmethod def extract_init_dict(cls, config_dict, **kwargs): @@ -650,3 +654,20 @@ def inner_init(self, *args, **kwargs): init(self, *args, **init_kwargs) return inner_init + + +class LegacyConfigMixin(ConfigMixin): + r""" + A subclass of `ConfigMixin` to resolve class mapping from legacy classes (like `Transformer2DModel`) to more + pipeline-specific classes (like `DiTTransformer2DModel`). + """ + + @classmethod + def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs): + # To prevent depedency import problem. + from .models.model_loading_utils import _fetch_remapped_cls_from_config + + # resolve remapping + remapped_class = _fetch_remapped_cls_from_config(config, cls) + + return remapped_class.from_config(config, return_unused_kwargs, **kwargs) diff --git a/mindone/diffusers/image_processor.py b/mindone/diffusers/image_processor.py index 504f8e580e..f0e8d6e0f0 100644 --- a/mindone/diffusers/image_processor.py +++ b/mindone/diffusers/image_processor.py @@ -37,6 +37,25 @@ PipelineDepthInput = PipelineImageInput +def is_valid_image(image): + return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, ms.Tensor)) and image.ndim in (2, 3) + + +def is_valid_image_imagelist(images): + # check if the image input is one of the supported formats for image and image list: + # it can be either one of below 3 + # (1) a 4d pytorch tensor or numpy array, + # (2) a valid image: PIL.Image.Image, 2-d np.ndarray or torch.Tensor (grayscale image), 3-d np.ndarray or torch.Tensor + # (3) a list of valid image + if isinstance(images, (np.ndarray, ms.Tensor)) and images.ndim == 4: + return True + elif is_valid_image(images): + return True + elif isinstance(images, list): + return all(is_valid_image(image) for image in images) + return False + + class VaeImageProcessor(ConfigMixin): """ Image processor for VAE. @@ -80,7 +99,6 @@ def __init__( " if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.", " if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`", ) - self.config.do_convert_rgb = False @staticmethod def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]: @@ -173,8 +191,9 @@ def blur(image: PIL.Image.Image, blur_factor: int = 4) -> PIL.Image.Image: @staticmethod def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0): """ - Finds a rectangular region that contains all masked ares in an image, and expands region to match the aspect ratio of the original image; - for example, if user drew mask in a 128x32 region, and the dimensions for processing are 512x512, the region will be expanded to 128x128. + Finds a rectangular region that contains all masked ares in an image, and expands region to match the aspect + ratio of the original image; for example, if user drew mask in a 128x32 region, and the dimensions for + processing are 512x512, the region will be expanded to 128x128. Args: mask_image (PIL.Image.Image): Mask image. @@ -183,7 +202,8 @@ def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0) pad (int, optional): Padding to be added to the crop region. Defaults to 0. Returns: - tuple: (x1, y1, x2, y2) represent a rectangular region that contains all masked ares in an image and matches the original aspect ratio. + tuple: (x1, y1, x2, y2) represent a rectangular region that contains all masked ares in an image and + matches the original aspect ratio. """ mask_image = mask_image.convert("L") @@ -265,8 +285,8 @@ def _resize_and_fill( height: int, ) -> PIL.Image.Image: """ - Resize the image to fit within the specified width and height, maintaining the aspect ratio, - and then center the image within the dimensions, filling empty with data from image. + Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center + the image within the dimensions, filling empty with data from image. Args: image: The image to resize. @@ -310,8 +330,8 @@ def _resize_and_crop( height: int, ) -> PIL.Image.Image: """ - Resize the image to fit within the specified width and height, maintaining the aspect ratio, - and then center the image within the dimensions, cropping the excess. + Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center + the image within the dimensions, cropping the excess. Args: image: The image to resize. @@ -348,12 +368,12 @@ def resize( The width to resize to. resize_mode (`str`, *optional*, defaults to `default`): The resize mode to use, can be one of `default` or `fill`. If `default`, will resize the image to fit - within the specified width and height, and it may not maintaining the original aspect ratio. - If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image - within the dimensions, filling empty with data from image. - If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image - within the dimensions, cropping the excess. - Note that resize_mode `fill` and `crop` are only supported for PIL image input. + within the specified width and height, and it may not maintaining the original aspect ratio. If `fill`, + will resize the image to fit within the specified width and height, maintaining the aspect ratio, and + then center the image within the dimensions, filling empty with data from image. If `crop`, will resize + the image to fit within the specified width and height, maintaining the aspect ratio, and then center + the image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only + supported for PIL image input. Returns: `PIL.Image.Image`, `np.ndarray` or `ms.Tensor`: @@ -458,19 +478,21 @@ def preprocess( Args: image (`pipeline_image_input`): - The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of supported formats. + The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of + supported formats. height (`int`, *optional*, defaults to `None`): - The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default height. + The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default + height. width (`int`, *optional*`, defaults to `None`): - The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width. + The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width. resize_mode (`str`, *optional*, defaults to `default`): - The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit - within the specified width and height, and it may not maintaining the original aspect ratio. - If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image - within the dimensions, filling empty with data from image. - If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image - within the dimensions, cropping the excess. - Note that resize_mode `fill` and `crop` are only supported for PIL image input. + The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within + the specified width and height, and it may not maintaining the original aspect ratio. If `fill`, will + resize the image to fit within the specified width and height, maintaining the aspect ratio, and then + center the image within the dimensions, filling empty with data from image. If `crop`, will resize the + image to fit within the specified width and height, maintaining the aspect ratio, and then center the + image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only + supported for PIL image input. crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`): The crop coordinates for each image in the batch. If `None`, will not crop the image. """ @@ -494,12 +516,27 @@ def preprocess( else: image = np.expand_dims(image, axis=-1) - if isinstance(image, supported_formats): - image = [image] - elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)): + if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4: + warnings.warn( + "Passing `image` as a list of 4d np.ndarray is deprecated." + "Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray", + FutureWarning, + ) + image = np.concatenate(image, axis=0) + if isinstance(image, list) and isinstance(image[0], ms.Tensor) and image[0].ndim == 4: + warnings.warn( + "Passing `image` as a list of 4d torch.Tensor is deprecated." + "Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor", + FutureWarning, + ) + image = ops.cat(image, axis=0) + + if not is_valid_image_imagelist(image): raise ValueError( - f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(supported_formats)}" + f"Input is in incorrect format. Currently, we only support {', '.join(str(x) for x in supported_formats)}" ) + if not isinstance(image, list): + image = [image] if isinstance(image[0], PIL.Image.Image): if crops_coords is not None: @@ -691,8 +728,8 @@ def __init__( @staticmethod def downsample(mask: ms.Tensor, batch_size: int, num_queries: int, value_embed_dim: int): """ - Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention. - If the aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued. + Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention. If the + aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued. Args: mask (`ms.Tensor`): @@ -739,3 +776,77 @@ def downsample(mask: ms.Tensor, batch_size: int, num_queries: int, value_embed_d ) return mask_downsample + + +class PixArtImageProcessor(VaeImageProcessor): + """ + Image processor for PixArt image resize and crop. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept + `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method. + vae_scale_factor (`int`, *optional*, defaults to `8`): + VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor. + resample (`str`, *optional*, defaults to `lanczos`): + Resampling filter to use when resizing the image. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image to [-1,1]. + do_binarize (`bool`, *optional*, defaults to `False`): + Whether to binarize the image to 0/1. + do_convert_rgb (`bool`, *optional*, defaults to be `False`): + Whether to convert the images to RGB format. + do_convert_grayscale (`bool`, *optional*, defaults to be `False`): + Whether to convert the images to grayscale format. + """ + + @register_to_config + def __init__( + self, + do_resize: bool = True, + vae_scale_factor: int = 8, + resample: str = "lanczos", + do_normalize: bool = True, + do_binarize: bool = False, + do_convert_grayscale: bool = False, + ): + super().__init__( + do_resize=do_resize, + vae_scale_factor=vae_scale_factor, + resample=resample, + do_normalize=do_normalize, + do_binarize=do_binarize, + do_convert_grayscale=do_convert_grayscale, + ) + + @staticmethod + def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]: + """Returns binned height and width.""" + ar = float(height / width) + closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar)) + default_hw = ratios[closest_ratio] + return int(default_hw[0]), int(default_hw[1]) + + @staticmethod + def resize_and_crop_tensor(samples: ms.Tensor, new_width: int, new_height: int) -> ms.Tensor: + orig_height, orig_width = samples.shape[2], samples.shape[3] + + # Check if resizing is needed + if orig_height != new_height or orig_width != new_width: + ratio = max(new_height / orig_height, new_width / orig_width) + resized_width = int(orig_width * ratio) + resized_height = int(orig_height * ratio) + + # Resize + samples = ops.interpolate( + samples, size=(resized_height, resized_width), mode="bilinear", align_corners=False + ) + + # Center Crop + start_x = (resized_width - new_width) // 2 + end_x = start_x + new_width + start_y = (resized_height - new_height) // 2 + end_y = start_y + new_height + samples = samples[:, :, start_y:end_y, start_x:end_x] + + return samples diff --git a/mindone/diffusers/loaders/__init__.py b/mindone/diffusers/loaders/__init__.py index 9f39c92d42..b052611520 100644 --- a/mindone/diffusers/loaders/__init__.py +++ b/mindone/diffusers/loaders/__init__.py @@ -49,8 +49,7 @@ def text_encoder_attn_modules(text_encoder): _import_structure = { - "autoencoder": ["FromOriginalVAEMixin"], - "controlnet": ["FromOriginalControlNetMixin"], + "single_file_model": ["FromOriginalModelMixin"], "ip_adapter": ["IPAdapterMixin"], "lora": ["LoraLoaderMixin", "SD3LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin"], "peft": ["PeftAdapterMixin"], @@ -61,12 +60,11 @@ def text_encoder_attn_modules(text_encoder): if TYPE_CHECKING: - from .autoencoder import FromOriginalVAEMixin - from .controlnet import FromOriginalControlNetMixin from .ip_adapter import IPAdapterMixin from .lora import LoraLoaderMixin, SD3LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin from .peft import PeftAdapterMixin from .single_file import FromSingleFileMixin + from .single_file_model import FromOriginalModelMixin from .textual_inversion import TextualInversionLoaderMixin from .unet import UNet2DConditionLoadersMixin else: diff --git a/mindone/diffusers/loaders/autoencoder.py b/mindone/diffusers/loaders/autoencoder.py deleted file mode 100644 index d5d3c343d7..0000000000 --- a/mindone/diffusers/loaders/autoencoder.py +++ /dev/null @@ -1,143 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from huggingface_hub.utils import validate_hf_hub_args - -from .single_file_utils import create_diffusers_vae_model_from_ldm, fetch_ldm_config_and_checkpoint - - -class FromOriginalVAEMixin: - """ - Load pretrained AutoencoderKL weights saved in the `.ckpt` or `.safetensors` format into a [`AutoencoderKL`]. - """ - - @classmethod - @validate_hf_hub_args - def from_single_file(cls, pretrained_model_link_or_path, **kwargs): - r""" - Instantiate a [`AutoencoderKL`] from pretrained ControlNet weights saved in the original `.ckpt` or - `.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default. - - Parameters: - pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*): - Can be either: - - A link to the `.ckpt` file (for example - `"https://huggingface.co//blob/main/.ckpt"`) on the Hub. - - A path to a *file* containing all pipeline weights. - config_file (`str`, *optional*): - Filepath to the configuration YAML file associated with the model. If not provided it will default to: - https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml - mindspore_dtype (`str` or `mindspore.dtype`, *optional*): - Override the default `mindspore_dtype` and load the model with another dtype. If `"auto"` is passed, the - dtype is automatically derived from the model's weights. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached if the standard cache - is not used. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to resume downloading the model weights and configuration files. If set to `False`, any - incompletely downloaded files are deleted. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether to only load local model weights and configuration files or not. If set to True, the model - won't be downloaded from the Hub. - token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from - `diffusers-cli login` (stored in `~/.huggingface`) is used. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier - allowed by Git. - image_size (`int`, *optional*, defaults to 512): - The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable - Diffusion v2 base model. Use 768 for Stable Diffusion v2. - scaling_factor (`float`, *optional*, defaults to 0.18215): - The component-wise standard deviation of the trained latent space computed using the first batch of the - training set. This is used to scale the latent space to have unit variance when training the diffusion - model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the - diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z - = 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution - Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. - kwargs (remaining dictionary of keyword arguments, *optional*): - Can be used to overwrite load and saveable variables (for example the pipeline components of the - specific pipeline class). The overwritten components are directly passed to the pipelines `__init__` - method. See example below for more information. - - - - Make sure to pass both `image_size` and `scaling_factor` to `from_single_file()` if you're loading - a VAE from SDXL or a Stable Diffusion v2 model or higher. - - - - Examples: - - ```py - from mindone.diffusers import AutoencoderKL - - url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors" # can also be local file - model = AutoencoderKL.from_single_file(url) - ``` - """ - - original_config_file = kwargs.pop("original_config_file", None) - config_file = kwargs.pop("config_file", None) - resume_download = kwargs.pop("resume_download", False) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - token = kwargs.pop("token", None) - cache_dir = kwargs.pop("cache_dir", None) - local_files_only = kwargs.pop("local_files_only", None) - revision = kwargs.pop("revision", None) - mindspore_dtype = kwargs.pop("mindspore_dtype", None) - - class_name = cls.__name__ - - if (config_file is not None) and (original_config_file is not None): - raise ValueError( - "You cannot pass both `config_file` and `original_config_file` to `from_single_file`. Please use only one of these arguments." - ) - - original_config_file = original_config_file or config_file - original_config, checkpoint = fetch_ldm_config_and_checkpoint( - pretrained_model_link_or_path=pretrained_model_link_or_path, - class_name=class_name, - original_config_file=original_config_file, - resume_download=resume_download, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - local_files_only=local_files_only, - cache_dir=cache_dir, - ) - - image_size = kwargs.pop("image_size", None) - scaling_factor = kwargs.pop("scaling_factor", None) - component = create_diffusers_vae_model_from_ldm( - class_name, - original_config, - checkpoint, - image_size=image_size, - scaling_factor=scaling_factor, - mindspore_dtype=mindspore_dtype, - ) - vae = component["vae"] - if mindspore_dtype is not None: - vae = vae.to(mindspore_dtype) - - return vae diff --git a/mindone/diffusers/loaders/controlnet.py b/mindone/diffusers/loaders/controlnet.py deleted file mode 100644 index bc7ba91be3..0000000000 --- a/mindone/diffusers/loaders/controlnet.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from huggingface_hub.utils import validate_hf_hub_args - -from .single_file_utils import create_diffusers_controlnet_model_from_ldm, fetch_ldm_config_and_checkpoint - - -class FromOriginalControlNetMixin: - """ - Load pretrained ControlNet weights saved in the `.ckpt` or `.safetensors` format into a [`ControlNetModel`]. - """ - - @classmethod - @validate_hf_hub_args - def from_single_file(cls, pretrained_model_link_or_path, **kwargs): - r""" - Instantiate a [`ControlNetModel`] from pretrained ControlNet weights saved in the original `.ckpt` or - `.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default. - - Parameters: - pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*): - Can be either: - - A link to the `.ckpt` file (for example - `"https://huggingface.co//blob/main/.ckpt"`) on the Hub. - - A path to a *file* containing all pipeline weights. - config_file (`str`, *optional*): - Filepath to the configuration YAML file associated with the model. If not provided it will default to: - https://raw.githubusercontent.com/lllyasviel/ControlNet/main/models/cldm_v15.yaml - mindspore_dtype (`str` or `mindspore.dtype`, *optional*): - Override the default `mindspore.dtype` and load the model with another dtype. If `"auto"` is passed, the - dtype is automatically derived from the model's weights. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached if the standard cache - is not used. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to resume downloading the model weights and configuration files. If set to `False`, any - incompletely downloaded files are deleted. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether to only load local model weights and configuration files or not. If set to True, the model - won't be downloaded from the Hub. - token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from - `diffusers-cli login` (stored in `~/.huggingface`) is used. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier - allowed by Git. - image_size (`int`, *optional*, defaults to 512): - The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable - Diffusion v2 base model. Use 768 for Stable Diffusion v2. - upcast_attention (`bool`, *optional*, defaults to `None`): - Whether the attention computation should always be upcasted. - kwargs (remaining dictionary of keyword arguments, *optional*): - Can be used to overwrite load and saveable variables (for example the pipeline components of the - specific pipeline class). The overwritten components are directly passed to the pipelines `__init__` - method. See example below for more information. - - Examples: - - ```py - from mindone.diffusers import StableDiffusionControlNetPipeline, ControlNetModel - - url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" # can also be a local path - model = ControlNetModel.from_single_file(url) - - url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors" # can also be a local path - pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet) - ``` - """ - original_config_file = kwargs.pop("original_config_file", None) - config_file = kwargs.pop("config_file", None) - resume_download = kwargs.pop("resume_download", False) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - token = kwargs.pop("token", None) - cache_dir = kwargs.pop("cache_dir", None) - local_files_only = kwargs.pop("local_files_only", None) - revision = kwargs.pop("revision", None) - mindspore_dtype = kwargs.pop("mindspore_dtype", None) - - class_name = cls.__name__ - if (config_file is not None) and (original_config_file is not None): - raise ValueError( - "You cannot pass both `config_file` and `original_config_file` to `from_single_file`. Please use only one of these arguments." - ) - - original_config_file = config_file or original_config_file - original_config, checkpoint = fetch_ldm_config_and_checkpoint( - pretrained_model_link_or_path=pretrained_model_link_or_path, - class_name=class_name, - original_config_file=original_config_file, - resume_download=resume_download, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - local_files_only=local_files_only, - cache_dir=cache_dir, - ) - - upcast_attention = kwargs.pop("upcast_attention", False) - image_size = kwargs.pop("image_size", None) - - component = create_diffusers_controlnet_model_from_ldm( - class_name, - original_config, - checkpoint, - upcast_attention=upcast_attention, - image_size=image_size, - mindspore_dtype=mindspore_dtype, - ) - controlnet = component["controlnet"] - if mindspore_dtype is not None: - controlnet = controlnet.to(mindspore_dtype) - - return controlnet diff --git a/mindone/diffusers/loaders/ip_adapter.py b/mindone/diffusers/loaders/ip_adapter.py index f073999936..5e36bca0da 100644 --- a/mindone/diffusers/loaders/ip_adapter.py +++ b/mindone/diffusers/loaders/ip_adapter.py @@ -23,8 +23,9 @@ from mindone.safetensors.mindspore import load_file from mindone.transformers import CLIPVisionModelWithProjection -from ..models.attention_processor import IPAdapterAttnProcessor +from ..models.attention_processor import AttnProcessor, IPAdapterAttnProcessor from ..utils import _get_model_file, logging +from .unet_loader_utils import _maybe_expand_lora_scales logger = logging.get_logger(__name__) @@ -52,26 +53,27 @@ def load_ip_adapter( with [`ModelMixin.save_pretrained`]. - A [mindspore state dict] subfolder (`str` or `List[str]`): - The subfolder location of a model file within a larger model repository on the Hub or locally. - If a list is passed, it should have the same length as `weight_name`. + The subfolder location of a model file within a larger model repository on the Hub or locally. If a + list is passed, it should have the same length as `weight_name`. weight_name (`str` or `List[str]`): The name of the weight file to load. If a list is passed, it should have the same length as `weight_name`. image_encoder_folder (`str`, *optional*, defaults to `image_encoder`): The subfolder location of the image encoder within a larger model repository on the Hub or locally. - Pass `None` to not load the image encoder. If the image encoder is located in a folder inside `subfolder`, - you only need to pass the name of the folder that contains image encoder weights, e.g. `image_encoder_folder="image_encoder"`. - If the image encoder is located in a folder other than `subfolder`, you should pass the path to the folder that contains image encoder weights, - for example, `image_encoder_folder="different_subfolder/image_encoder"`. + Pass `None` to not load the image encoder. If the image encoder is located in a folder inside + `subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g. + `image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than + `subfolder`, you should pass the path to the folder that contains image encoder weights, for example, + `image_encoder_folder="different_subfolder/image_encoder"`. cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory where a downloaded pretrained model configuration is cached if the standard cache is not used. force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to resume downloading the model weights and configuration files. If set to `False`, any - incompletely downloaded files are deleted. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1 + of Diffusers. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. @@ -109,7 +111,7 @@ def load_ip_adapter( # Load the main state dict first. cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) - resume_download = kwargs.pop("resume_download", False) + resume_download = kwargs.pop("resume_download", None) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", None) token = kwargs.pop("token", None) @@ -187,27 +189,66 @@ def load_ip_adapter( unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet unet._load_ip_adapter_weights(state_dicts) + extra_loras = unet._load_ip_adapter_loras(state_dicts) + if extra_loras != {}: + # apply the IP Adapter Face ID LoRA weights + peft_config = getattr(unet, "peft_config", {}) + for k, lora in extra_loras.items(): + if f"faceid_{k}" not in peft_config: + self.load_lora_weights(lora, adapter_name=f"faceid_{k}") + self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0]) + def set_ip_adapter_scale(self, scale): """ - Sets the conditioning scale between text and image. + Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for + granular control over each IP-Adapter behavior. A config can be a float or a dictionary. Example: ```py - pipeline.set_ip_adapter_scale(0.5) + # To use original IP-Adapter + scale = 1.0 + pipeline.set_ip_adapter_scale(scale) + + # To use style block only + scale = { + "up": {"block_0": [0.0, 1.0, 0.0]}, + } + pipeline.set_ip_adapter_scale(scale) + + # To use style+layout blocks + scale = { + "down": {"block_2": [0.0, 1.0]}, + "up": {"block_0": [0.0, 1.0, 0.0]}, + } + pipeline.set_ip_adapter_scale(scale) + + # To use style and layout from 2 reference images + scales = [{"down": {"block_2": [0.0, 1.0]}}, {"up": {"block_0": [0.0, 1.0, 0.0]}}] + pipeline.set_ip_adapter_scale(scales) ``` """ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet - for attn_processor in unet.attn_processors.values(): + if not isinstance(scale, list): + scale = [scale] + scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0) + + for attn_name, attn_processor in unet.attn_processors.items(): if isinstance(attn_processor, (IPAdapterAttnProcessor)): - if not isinstance(scale, list): - scale = [scale] * len(attn_processor.scale) - if len(attn_processor.scale) != len(scale): + if len(scale_configs) != len(attn_processor.scale): raise ValueError( - f"`scale` should be a list of same length as the number if ip-adapters " - f"Expected {len(attn_processor.scale)} but got {len(scale)}." + f"Cannot assign {len(scale_configs)} scale_configs to " + f"{len(attn_processor.scale)} IP-Adapter." ) - attn_processor.scale = scale + elif len(scale_configs) == 1: + scale_configs = scale_configs * len(attn_processor.scale) + for i, scale_config in enumerate(scale_configs): + if isinstance(scale_config, dict): + for k, s in scale_config.items(): + if attn_name.startswith(k): + attn_processor.scale[i] = s + else: + attn_processor.scale[i] = scale_config def unload_ip_adapter(self): """ @@ -238,4 +279,8 @@ def unload_ip_adapter(self): self.config.encoder_hid_dim_type = None # restore original Unet attention processors layers - self.unet.set_default_attn_processor() + attn_procs = {} + for name, value in self.unet.attn_processors.items(): + attn_processor_class = AttnProcessor() + attn_procs[name] = attn_processor_class if isinstance(value, IPAdapterAttnProcessor) else value.__class__() + self.unet.set_attn_processor(attn_procs) diff --git a/mindone/diffusers/loaders/lora.py b/mindone/diffusers/loaders/lora.py index 821f84e65b..3b05a6bd9e 100644 --- a/mindone/diffusers/loaders/lora.py +++ b/mindone/diffusers/loaders/lora.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import copy import inspect import os from pathlib import Path @@ -31,17 +32,17 @@ _get_model_file, convert_state_dict_to_diffusers, convert_state_dict_to_peft, - convert_unet_state_dict_to_peft, delete_adapter_layers, get_adapter_name, get_peft_kwargs, + is_peft_version, logging, recurse_remove_peft_layers, scale_lora_layers, set_adapter_layers, set_weights_and_activate_adapters, ) -from .lora_conversion_utils import _convert_kohya_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers +from .lora_conversion_utils import _convert_non_diffusers_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers logger = logging.get_logger(__name__) @@ -49,7 +50,7 @@ UNET_NAME = "unet" TRANSFORMER_NAME = "transformer" -LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" +LORA_WEIGHT_NAME = "pytorch_lora_weights.ckpt" LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" LORA_DEPRECATION_MESSAGE = "You are using an old version of LoRA backend. This will be deprecated in the next releases in favor of PEFT make sure to install the latest PEFT and transformers packages in the future." # noqa: E501 @@ -127,7 +128,7 @@ def load_lora_weights( # First, ensure that the checkpoint is a compatible one and can be successfully loaded. state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - is_correct_format = all("lora" in key for key in state_dict.keys()) + is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") @@ -184,9 +185,9 @@ def lora_state_dict( force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to resume downloading the model weights and configuration files. If set to `False`, any - incompletely downloaded files are deleted. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1 + of Diffusers. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. @@ -201,17 +202,14 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - mirror (`str`, *optional*): - Mirror source to resolve accessibility issues if you're downloading a model in China. We do not - guarantee the timeliness or safety of the source, and you should refer to the mirror site for more - information. - + weight_name (`str`, *optional*, defaults to None): + Name of the serialized state dict file. """ # Load the main state dict first which has the LoRA layers for either of # UNet and text encoder or both. cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) - resume_download = kwargs.pop("resume_download", False) + resume_download = kwargs.pop("resume_download", None) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", None) token = kwargs.pop("token", None) @@ -307,7 +305,7 @@ def lora_state_dict( if unet_config is not None: # use unet config to remap block numbers state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config) - state_dict, network_alphas = _convert_kohya_lora_to_diffusers(state_dict) + state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict) return state_dict, network_alphas @@ -363,77 +361,27 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None into the unet or prefixed with an additional `unet` which can be used to distinguish between text encoder lora layers. network_alphas (`Dict[str, float]`): - See `LoRALinearLayer` for more details. + The value of the network alpha used for stable learning and preventing underflow. This value has the + same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this + link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). unet (`UNet2DConditionModel`): The UNet model to load the LoRA layers into. adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. """ - from mindone.diffusers._peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as # their prefixes. keys = list(state_dict.keys()) - - if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys): + only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys) + if not only_text_encoder: # Load the layers corresponding to UNet. logger.info(f"Loading {cls.unet_name}.") - - unet_keys = [k for k in keys if k.startswith(cls.unet_name)] - state_dict = {k.replace(f"{cls.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys} - - if network_alphas is not None: - alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.unet_name)] - network_alphas = { - k.replace(f"{cls.unet_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys - } - - else: - # Otherwise, we're dealing with the old format. This means the `state_dict` should only - # contain the module names of the `unet` as its keys WITHOUT any prefix. - warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`." # noqa: E501 - logger.warning(warn_message) - - if len(state_dict.keys()) > 0: - if adapter_name in getattr(unet, "peft_config", {}): - raise ValueError( - f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name." - ) - - state_dict = convert_unet_state_dict_to_peft(state_dict) - - if network_alphas is not None: - # The alphas state dict have the same structure as Unet, thus we convert it to peft format using - # `convert_unet_state_dict_to_peft` method. - network_alphas = convert_unet_state_dict_to_peft(network_alphas) - - rank = {} - for key, val in state_dict.items(): - if "lora_B" in key: - rank[key] = val.shape[1] - - lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True) - lora_config = LoraConfig(**lora_config_kwargs) - - # adapter_name - if adapter_name is None: - adapter_name = get_adapter_name(unet) - - inject_adapter_in_model(lora_config, unet, adapter_name=adapter_name) - incompatible_keys = set_peft_model_state_dict(unet, state_dict, adapter_name) - - if incompatible_keys is not None: - # check only for unexpected keys - unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) - if unexpected_keys: - logger.warning( - f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " - f" {unexpected_keys}. " - ) - - unet.load_attn_procs(state_dict, network_alphas=network_alphas, _pipeline=_pipeline) + unet.load_attn_procs( + state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline + ) @classmethod def load_lora_into_text_encoder( @@ -512,6 +460,15 @@ def load_lora_into_text_encoder( } lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"]: + if is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<", "0.9.0"): + lora_config_kwargs.pop("use_dora") lora_config = LoraConfig(**lora_config_kwargs) # adapter_name @@ -576,6 +533,13 @@ def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, ada rank[key] = val.shape[1] lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict) + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + else: + lora_config_kwargs.pop("use_dora") lora_config = LoraConfig(**lora_config_kwargs) # adapter_name @@ -727,10 +691,7 @@ def unload_lora_weights(self): ``` """ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet - - recurse_remove_peft_layers(unet) - if hasattr(unet, "peft_config"): - del unet.peft_config + unet.unload_lora() # Safe to call the following regardless of LoRA. self._remove_text_encoder_monkey_patch() @@ -856,7 +817,7 @@ def set_adapters_for_text_encoder( self, adapter_names: Union[List[str], str], text_encoder: Optional["MSPreTrainedModel"] = None, # noqa: F821 - text_encoder_weights: List[float] = None, + text_encoder_weights: Optional[Union[float, List[float], List[None]]] = None, ): """ Sets the adapter layers for the text encoder. @@ -872,15 +833,20 @@ def set_adapters_for_text_encoder( """ def process_weights(adapter_names, weights): - if weights is None: - weights = [1.0] * len(adapter_names) - elif isinstance(weights, float): - weights = [weights] + # Expand weights into a list, one entry per adapter + # e.g. for 2 adapters: 7 -> [7,7] ; [3, None] -> [3, None] + if not isinstance(weights, list): + weights = [weights] * len(adapter_names) if len(adapter_names) != len(weights): raise ValueError( f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}" ) + + # Set None values to default of 1.0 + # e.g. [7,7] -> [7,7] ; [3, None] -> [3,1] + weights = [w if w is not None else 1.0 for w in weights] + return weights adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names @@ -923,17 +889,77 @@ def enable_lora_for_text_encoder(self, text_encoder: Optional["MSPreTrainedModel def set_adapters( self, adapter_names: Union[List[str], str], - adapter_weights: Optional[List[float]] = None, + adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None, ): + adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names + + adapter_weights = copy.deepcopy(adapter_weights) + + # Expand weights into a list, one entry per adapter + if not isinstance(adapter_weights, list): + adapter_weights = [adapter_weights] * len(adapter_names) + + if len(adapter_names) != len(adapter_weights): + raise ValueError( + f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}" + ) + + # Decompose weights into weights for unet, text_encoder and text_encoder_2 + unet_lora_weights, text_encoder_lora_weights, text_encoder_2_lora_weights = [], [], [] + + list_adapters = self.get_list_adapters() # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]} + all_adapters = { + adapter for adapters in list_adapters.values() for adapter in adapters + } # eg ["adapter1", "adapter2"] + invert_list_adapters = { + adapter: [part for part, adapters in list_adapters.items() if adapter in adapters] + for adapter in all_adapters + } # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]} + + for adapter_name, weights in zip(adapter_names, adapter_weights): + if isinstance(weights, dict): + unet_lora_weight = weights.pop("unet", None) + text_encoder_lora_weight = weights.pop("text_encoder", None) + text_encoder_2_lora_weight = weights.pop("text_encoder_2", None) + + if len(weights) > 0: + raise ValueError( + f"Got invalid key '{weights.keys()}' in lora weight dict for adapter {adapter_name}." + ) + + if text_encoder_2_lora_weight is not None and not hasattr(self, "text_encoder_2"): + logger.warning( + "Lora weight dict contains text_encoder_2 weights but will be ignored because pipeline does not have text_encoder_2." + ) + + # warn if adapter doesn't have parts specified by adapter_weights + for part_weight, part_name in zip( + [unet_lora_weight, text_encoder_lora_weight, text_encoder_2_lora_weight], + ["unet", "text_encoder", "text_encoder_2"], + ): + if part_weight is not None and part_name not in invert_list_adapters[adapter_name]: + logger.warning( + f"Lora weight dict for adapter '{adapter_name}' contains {part_name}, but this will be ignored because {adapter_name} does not contain weights for {part_name}. Valid parts for {adapter_name} are: {invert_list_adapters[adapter_name]}." # noqa E501 + ) + + else: + unet_lora_weight = weights + text_encoder_lora_weight = weights + text_encoder_2_lora_weight = weights + + unet_lora_weights.append(unet_lora_weight) + text_encoder_lora_weights.append(text_encoder_lora_weight) + text_encoder_2_lora_weights.append(text_encoder_2_lora_weight) + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet # Handle the UNET - unet.set_adapters(adapter_names, adapter_weights) + unet.set_adapters(adapter_names, unet_lora_weights) # Handle the Text Encoder if hasattr(self, "text_encoder"): - self.set_adapters_for_text_encoder(adapter_names, self.text_encoder, adapter_weights) + self.set_adapters_for_text_encoder(adapter_names, self.text_encoder, text_encoder_lora_weights) if hasattr(self, "text_encoder_2"): - self.set_adapters_for_text_encoder(adapter_names, self.text_encoder_2, adapter_weights) + self.set_adapters_for_text_encoder(adapter_names, self.text_encoder_2, text_encoder_2_lora_weights) def disable_lora(self): # Disable unet adapters @@ -1071,7 +1097,7 @@ def load_lora_weights( unet_config=self.unet.config, **kwargs, ) - is_correct_format = all("lora" in key for key in state_dict.keys()) + is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") @@ -1125,6 +1151,9 @@ def save_lora_weights( text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text encoder LoRA state dict because it comes from 🤗 Transformers. + text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text + encoder LoRA state dict because it comes from 🤗 Transformers. is_main_process (`bool`, *optional*, defaults to `True`): Whether the process calling this is the main process or not. Useful during distributed training and you need to call this function on all processes. In this case, set `is_main_process=True` only on the main @@ -1151,8 +1180,10 @@ def pack_weights(layers, prefix): if unet_lora_layers: state_dict.update(pack_weights(unet_lora_layers, "unet")) - if text_encoder_lora_layers and text_encoder_2_lora_layers: + if text_encoder_lora_layers: state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder")) + + if text_encoder_2_lora_layers: state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) cls.write_lora_layers( @@ -1215,7 +1246,7 @@ def load_lora_weights( # First, ensure that the checkpoint is a compatible one and can be successfully loaded. state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - is_correct_format = all("lora" in key for key in state_dict.keys()) + is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") @@ -1389,6 +1420,13 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, rank[key] = val.shape[1] lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict) + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + else: + lora_config_kwargs.pop("use_dora") lora_config = LoraConfig(**lora_config_kwargs) # adapter_name diff --git a/mindone/diffusers/loaders/lora_conversion_utils.py b/mindone/diffusers/loaders/lora_conversion_utils.py index bd03988dbc..1815ff8a52 100644 --- a/mindone/diffusers/loaders/lora_conversion_utils.py +++ b/mindone/diffusers/loaders/lora_conversion_utils.py @@ -14,7 +14,7 @@ import re -from ..utils import logging +from ..utils import is_peft_version, logging logger = logging.get_logger(__name__) @@ -122,153 +122,108 @@ def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", b return new_state_dict -def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"): +def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"): + """ + Converts a non-Diffusers LoRA state dict to a Diffusers compatible state dict. + + Args: + state_dict (`dict`): The state dict to convert. + unet_name (`str`, optional): The name of the U-Net module in the Diffusers model. Defaults to "unet". + text_encoder_name (`str`, optional): The name of the text encoder module in the Diffusers model. Defaults to + "text_encoder". + + Returns: + `tuple`: A tuple containing the converted state dict and a dictionary of alphas. + """ unet_state_dict = {} te_state_dict = {} te2_state_dict = {} network_alphas = {} - # every down weight has a corresponding up weight and potentially an alpha weight - lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")] - for key in lora_keys: + # Check for DoRA-enabled LoRAs. + dora_present_in_unet = any("dora_scale" in k and "lora_unet_" in k for k in state_dict) + dora_present_in_te = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict) + dora_present_in_te2 = any("dora_scale" in k and "lora_te2_" in k for k in state_dict) + if dora_present_in_unet or dora_present_in_te or dora_present_in_te2: + if is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + + # Iterate over all LoRA weights. + all_lora_keys = list(state_dict.keys()) + for key in all_lora_keys: + if not key.endswith("lora_down.weight"): + continue + + # Extract LoRA name. lora_name = key.split(".")[0] + + # Find corresponding up weight and alpha. lora_name_up = lora_name + ".lora_up.weight" lora_name_alpha = lora_name + ".alpha" + # Handle U-Net LoRAs. if lora_name.startswith("lora_unet_"): - diffusers_name = key.replace("lora_unet_", "").replace("_", ".") + diffusers_name = _convert_unet_lora_key(key) - if "input.blocks" in diffusers_name: - diffusers_name = diffusers_name.replace("input.blocks", "down_blocks") - else: - diffusers_name = diffusers_name.replace("down.blocks", "down_blocks") + # Store down and up weights. + unet_state_dict[diffusers_name] = state_dict.pop(key) + unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) - if "middle.block" in diffusers_name: - diffusers_name = diffusers_name.replace("middle.block", "mid_block") - else: - diffusers_name = diffusers_name.replace("mid.block", "mid_block") - if "output.blocks" in diffusers_name: - diffusers_name = diffusers_name.replace("output.blocks", "up_blocks") - else: - diffusers_name = diffusers_name.replace("up.blocks", "up_blocks") - - diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks") - diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora") - diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora") - diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora") - diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora") - diffusers_name = diffusers_name.replace("proj.in", "proj_in") - diffusers_name = diffusers_name.replace("proj.out", "proj_out") - diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj") - - # SDXL specificity. - if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name: - pattern = r"\.\d+(?=\D*$)" - diffusers_name = re.sub(pattern, "", diffusers_name, count=1) - if ".in." in diffusers_name: - diffusers_name = diffusers_name.replace("in.layers.2", "conv1") - if ".out." in diffusers_name: - diffusers_name = diffusers_name.replace("out.layers.3", "conv2") - if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name: - diffusers_name = diffusers_name.replace("op", "conv") - if "skip" in diffusers_name: - diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut") - - # LyCORIS specificity. - if "time.emb.proj" in diffusers_name: - diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj") - if "conv.shortcut" in diffusers_name: - diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut") - - # General coverage. - if "transformer_blocks" in diffusers_name: - if "attn1" in diffusers_name or "attn2" in diffusers_name: - diffusers_name = diffusers_name.replace("attn1", "attn1.processor") - diffusers_name = diffusers_name.replace("attn2", "attn2.processor") - unet_state_dict[diffusers_name] = state_dict.pop(key) - unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) - elif "ff" in diffusers_name: - unet_state_dict[diffusers_name] = state_dict.pop(key) - unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) - elif any(_ in diffusers_name for _ in ("proj_in", "proj_out")): - unet_state_dict[diffusers_name] = state_dict.pop(key) - unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) - else: - unet_state_dict[diffusers_name] = state_dict.pop(key) - unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) - - elif lora_name.startswith("lora_te_"): - diffusers_name = key.replace("lora_te_", "").replace("_", ".") - diffusers_name = diffusers_name.replace("text.model", "text_model") - diffusers_name = diffusers_name.replace("self.attn", "self_attn") - diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora") - diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora") - diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora") - diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora") - if "self_attn" in diffusers_name: - te_state_dict[diffusers_name] = state_dict.pop(key) - te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) - elif "mlp" in diffusers_name: - # Be aware that this is the new diffusers convention and the rest of the code might - # not utilize it yet. - diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.") - te_state_dict[diffusers_name] = state_dict.pop(key) - te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + # Store DoRA scale if present. + if dora_present_in_unet: + dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down." + unet_state_dict[ + diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.") + ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale")) - # (sayakpaul): Duplicate code. Needs to be cleaned. - elif lora_name.startswith("lora_te1_"): - diffusers_name = key.replace("lora_te1_", "").replace("_", ".") - diffusers_name = diffusers_name.replace("text.model", "text_model") - diffusers_name = diffusers_name.replace("self.attn", "self_attn") - diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora") - diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora") - diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora") - diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora") - if "self_attn" in diffusers_name: - te_state_dict[diffusers_name] = state_dict.pop(key) - te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) - elif "mlp" in diffusers_name: - # Be aware that this is the new diffusers convention and the rest of the code might - # not utilize it yet. - diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.") + # Handle text encoder LoRAs. + elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")): + diffusers_name = _convert_text_encoder_lora_key(key, lora_name) + + # Store down and up weights for te or te2. + if lora_name.startswith(("lora_te_", "lora_te1_")): te_state_dict[diffusers_name] = state_dict.pop(key) te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) - - # (sayakpaul): Duplicate code. Needs to be cleaned. - elif lora_name.startswith("lora_te2_"): - diffusers_name = key.replace("lora_te2_", "").replace("_", ".") - diffusers_name = diffusers_name.replace("text.model", "text_model") - diffusers_name = diffusers_name.replace("self.attn", "self_attn") - diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora") - diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora") - diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora") - diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora") - if "self_attn" in diffusers_name: - te2_state_dict[diffusers_name] = state_dict.pop(key) - te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) - elif "mlp" in diffusers_name: - # Be aware that this is the new diffusers convention and the rest of the code might - # not utilize it yet. - diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.") + else: te2_state_dict[diffusers_name] = state_dict.pop(key) te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) - # Rename the alphas so that they can be mapped appropriately. + # Store DoRA scale if present. + if dora_present_in_te or dora_present_in_te2: + dora_scale_key_to_replace_te = ( + "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer." + ) + if lora_name.startswith(("lora_te_", "lora_te1_")): + te_state_dict[ + diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.") + ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale")) + elif lora_name.startswith("lora_te2_"): + te2_state_dict[ + diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.") + ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale")) + + # Store alpha if present. if lora_name_alpha in state_dict: alpha = state_dict.pop(lora_name_alpha).item() - if lora_name_alpha.startswith("lora_unet_"): - prefix = "unet." - elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")): - prefix = "text_encoder." - else: - prefix = "text_encoder_2." - new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha" - network_alphas.update({new_name: alpha}) + # When alpha comes from a Tensor with dtype bfloat16, the `item()` returns an instance + # of ml_dtypes.bfloat16 which couldn't be computed with float, int or tensors directly. + # Therefore we cast it to float to enable binary operation with others. + + # TODO: mindspore_bf16_tensor.item() should return to a python built-in float natively, + # push mindspore to do it. + if not isinstance(alpha, (int, float)): + alpha = float(alpha) + network_alphas.update(_get_alpha_name(lora_name_alpha, diffusers_name, alpha)) + # Check if any keys remain. if len(state_dict) > 0: - raise ValueError(f"The following keys have not been correctly be renamed: \n\n {', '.join(state_dict.keys())}") + raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}") - logger.info("Kohya-style checkpoint detected.") + logger.info("Non-diffusers checkpoint detected.") + + # Construct final state dict. unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()} te_state_dict = {f"{text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()} te2_state_dict = ( @@ -281,3 +236,100 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_ new_state_dict = {**unet_state_dict, **te_state_dict} return new_state_dict, network_alphas + + +def _convert_unet_lora_key(key): + """ + Converts a U-Net LoRA key to a Diffusers compatible key. + """ + diffusers_name = key.replace("lora_unet_", "").replace("_", ".") + + # Replace common U-Net naming patterns. + diffusers_name = diffusers_name.replace("input.blocks", "down_blocks") + diffusers_name = diffusers_name.replace("down.blocks", "down_blocks") + diffusers_name = diffusers_name.replace("middle.block", "mid_block") + diffusers_name = diffusers_name.replace("mid.block", "mid_block") + diffusers_name = diffusers_name.replace("output.blocks", "up_blocks") + diffusers_name = diffusers_name.replace("up.blocks", "up_blocks") + diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks") + diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora") + diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora") + diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora") + diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora") + diffusers_name = diffusers_name.replace("proj.in", "proj_in") + diffusers_name = diffusers_name.replace("proj.out", "proj_out") + diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj") + + # SDXL specific conversions. + if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name: + pattern = r"\.\d+(?=\D*$)" + diffusers_name = re.sub(pattern, "", diffusers_name, count=1) + if ".in." in diffusers_name: + diffusers_name = diffusers_name.replace("in.layers.2", "conv1") + if ".out." in diffusers_name: + diffusers_name = diffusers_name.replace("out.layers.3", "conv2") + if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name: + diffusers_name = diffusers_name.replace("op", "conv") + if "skip" in diffusers_name: + diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut") + + # LyCORIS specific conversions. + if "time.emb.proj" in diffusers_name: + diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj") + if "conv.shortcut" in diffusers_name: + diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut") + + # General conversions. + if "transformer_blocks" in diffusers_name: + if "attn1" in diffusers_name or "attn2" in diffusers_name: + diffusers_name = diffusers_name.replace("attn1", "attn1.processor") + diffusers_name = diffusers_name.replace("attn2", "attn2.processor") + elif "ff" in diffusers_name: + pass + elif any(key in diffusers_name for key in ("proj_in", "proj_out")): + pass + else: + pass + + return diffusers_name + + +def _convert_text_encoder_lora_key(key, lora_name): + """ + Converts a text encoder LoRA key to a Diffusers compatible key. + """ + if lora_name.startswith(("lora_te_", "lora_te1_")): + key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_" + else: + key_to_replace = "lora_te2_" + + diffusers_name = key.replace(key_to_replace, "").replace("_", ".") + diffusers_name = diffusers_name.replace("text.model", "text_model") + diffusers_name = diffusers_name.replace("self.attn", "self_attn") + diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora") + diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora") + diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora") + diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora") + diffusers_name = diffusers_name.replace("text.projection", "text_projection") + + if "self_attn" in diffusers_name or "text_projection" in diffusers_name: + pass + elif "mlp" in diffusers_name: + # Be aware that this is the new diffusers convention and the rest of the code might + # not utilize it yet. + diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.") + return diffusers_name + + +def _get_alpha_name(lora_name_alpha, diffusers_name, alpha): + """ + Gets the correct alpha name for the Diffusers model. + """ + if lora_name_alpha.startswith("lora_unet_"): + prefix = "unet." + elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")): + prefix = "text_encoder." + else: + prefix = "text_encoder_2." + new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha" + return {new_name: alpha} diff --git a/mindone/diffusers/loaders/peft.py b/mindone/diffusers/loaders/peft.py index 273f5539bf..3f957f61ec 100644 --- a/mindone/diffusers/loaders/peft.py +++ b/mindone/diffusers/loaders/peft.py @@ -18,7 +18,8 @@ class PeftAdapterMixin: """ A class containing all functions for loading and using adapters weights that are supported in PEFT library. For - more details about adapters and injecting them in a transformer-based model, check out the PEFT [documentation](https://huggingface.co/docs/peft/index). + more details about adapters and injecting them in a transformer-based model, check out the PEFT + [documentation](https://huggingface.co/docs/peft/index). Install the latest version of PEFT, and use this mixin to: @@ -123,8 +124,8 @@ def disable_adapters(self) -> None: def enable_adapters(self) -> None: """ - Enable adapters that are attached to the model. The model uses `self.active_adapters()` to retrieve the - list of adapters to enable. + Enable adapters that are attached to the model. The model uses `self.active_adapters()` to retrieve the list of + adapters to enable. If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT [documentation](https://huggingface.co/docs/peft). diff --git a/mindone/diffusers/loaders/single_file.py b/mindone/diffusers/loaders/single_file.py index 55daff3617..2e78a3c048 100644 --- a/mindone/diffusers/loaders/single_file.py +++ b/mindone/diffusers/loaders/single_file.py @@ -11,142 +11,253 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import importlib +import inspect +import os -from huggingface_hub.utils import validate_hf_hub_args -from transformers import AutoFeatureExtractor +from huggingface_hub import snapshot_download +from huggingface_hub.utils import LocalEntryNotFoundError, validate_hf_hub_args +from packaging import version -from ..utils import logging +from mindspore import nn + +from ..utils import deprecate, is_transformers_available, logging, maybe_import_module_in_mindone from .single_file_utils import ( - create_diffusers_unet_model_from_ldm, - create_diffusers_vae_model_from_ldm, - create_scheduler_from_ldm, - create_text_encoders_and_tokenizers_from_ldm, - fetch_ldm_config_and_checkpoint, - infer_model_type, + SingleFileComponentError, + _is_model_weights_in_cached_folder, + _legacy_load_clip_tokenizer, + _legacy_load_safety_checker, + _legacy_load_scheduler, + create_diffusers_clip_model_from_ldm, + create_diffusers_t5_model_from_checkpoint, + fetch_diffusers_config, + fetch_original_config, + is_clip_model_in_single_file, + is_t5_in_single_file, + load_single_file_checkpoint, ) logger = logging.get_logger(__name__) -# Pipelines that support the SDXL Refiner checkpoint -REFINER_PIPELINES = [ - "StableDiffusionXLImg2ImgPipeline", - "StableDiffusionXLInpaintPipeline", - "StableDiffusionXLControlNetImg2ImgPipeline", -] +# Legacy behaviour. `from_single_file` does not load the safety checker unless explicitly provided +SINGLE_FILE_OPTIONAL_COMPONENTS = ["safety_checker"] + + +if is_transformers_available(): + import transformers + from transformers import PreTrainedTokenizer + from mindone.transformers import MSPreTrainedModel as PreTrainedModel -def build_sub_model_components( - pipeline_components, - pipeline_class_name, - component_name, - original_config, + +def load_single_file_sub_model( + library_name, + class_name, + name, checkpoint, + pipelines, + is_pipeline_module, + cached_model_config_path, + original_config=None, local_files_only=False, - load_safety_checker=False, - model_type=None, - image_size=None, mindspore_dtype=None, + is_legacy_loading=False, **kwargs, ): - if component_name in pipeline_components: - return {} - if component_name == "unet": - num_in_channels = kwargs.pop("num_in_channels", None) - upcast_attention = kwargs.pop("upcast_attention", None) - - unet_components = create_diffusers_unet_model_from_ldm( - pipeline_class_name, - original_config, - checkpoint, - num_in_channels=num_in_channels, - image_size=image_size, + if is_pipeline_module: + pipeline_module = getattr(pipelines, library_name) + class_obj = getattr(pipeline_module, class_name) + else: + # else we just import it from the library. + library = maybe_import_module_in_mindone(library_name) + if hasattr(library, class_name): + class_obj = getattr(library, class_name) + else: + library = maybe_import_module_in_mindone(library_name, force_original=True) + class_obj = getattr(library, class_name) + + if is_transformers_available(): + transformers_version = version.parse(version.parse(transformers.__version__).base_version) + else: + transformers_version = "N/A" + + is_transformers_model = issubclass(class_obj, PreTrainedModel) + is_tokenizer = ( + is_transformers_available() + and issubclass(class_obj, PreTrainedTokenizer) + and transformers_version >= version.parse("4.20.0") + ) + + diffusers_module = importlib.import_module("mindone.diffusers") + is_diffusers_single_file_model = issubclass(class_obj, diffusers_module.FromOriginalModelMixin) + is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin) + is_diffusers_scheduler = issubclass(class_obj, diffusers_module.SchedulerMixin) + + if is_diffusers_single_file_model: + load_method = getattr(class_obj, "from_single_file") + + # We cannot provide two different config options to the `from_single_file` method + # Here we have to ignore loading the config from `cached_model_config_path` if `original_config` is provided + if original_config: + cached_model_config_path = None + + loaded_sub_model = load_method( + pretrained_model_link_or_path_or_dict=checkpoint, + original_config=original_config, + config=cached_model_config_path, + subfolder=name, mindspore_dtype=mindspore_dtype, - model_type=model_type, - upcast_attention=upcast_attention, + local_files_only=local_files_only, + **kwargs, ) - return unet_components - if component_name == "vae": - scaling_factor = kwargs.get("scaling_factor", None) - vae_components = create_diffusers_vae_model_from_ldm( - pipeline_class_name, - original_config, - checkpoint, - image_size, - scaling_factor, - mindspore_dtype, - model_type=model_type, - ) - return vae_components - - if component_name == "scheduler": - scheduler_type = kwargs.get("scheduler_type", "ddim") - prediction_type = kwargs.get("prediction_type", None) - - scheduler_components = create_scheduler_from_ldm( - pipeline_class_name, - original_config, - checkpoint, - scheduler_type=scheduler_type, - prediction_type=prediction_type, - model_type=model_type, + elif is_transformers_model and is_clip_model_in_single_file(class_obj, checkpoint): + loaded_sub_model = create_diffusers_clip_model_from_ldm( + class_obj, + checkpoint=checkpoint, + config=cached_model_config_path, + subfolder=name, + mindspore_dtype=mindspore_dtype, + local_files_only=local_files_only, + is_legacy_loading=is_legacy_loading, ) - return scheduler_components - - if component_name in ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"]: - text_encoder_components = create_text_encoders_and_tokenizers_from_ldm( - original_config, - checkpoint, - model_type=model_type, - local_files_only=local_files_only, + elif is_transformers_model and is_t5_in_single_file(checkpoint): + loaded_sub_model = create_diffusers_t5_model_from_checkpoint( + class_obj, + checkpoint=checkpoint, + config=cached_model_config_path, + subfolder=name, mindspore_dtype=mindspore_dtype, + local_files_only=local_files_only, ) - return text_encoder_components - if component_name == "safety_checker": - if load_safety_checker: - from ..pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker + elif is_tokenizer and is_legacy_loading: + loaded_sub_model = _legacy_load_clip_tokenizer( + class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only + ) - safety_checker = StableDiffusionSafetyChecker.from_pretrained( - "CompVis/stable-diffusion-safety-checker", - local_files_only=local_files_only, - mindspore_dtype=mindspore_dtype, - ) - else: - safety_checker = None - return {"safety_checker": safety_checker} + elif is_diffusers_scheduler and is_legacy_loading: + loaded_sub_model = _legacy_load_scheduler( + class_obj, checkpoint=checkpoint, component_name=name, original_config=original_config, **kwargs + ) - if component_name == "feature_extractor": - if load_safety_checker: - feature_extractor = AutoFeatureExtractor.from_pretrained( - "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only + else: + if not hasattr(class_obj, "from_pretrained"): + raise ValueError( + ( + f"The component {class_obj.__name__} cannot be loaded as it does not seem to have" + " a supported loading method." + ) ) - else: - feature_extractor = None - return {"feature_extractor": feature_extractor} - - return - -def set_additional_components( - pipeline_class_name, - original_config, - checkpoint=None, - model_type=None, -): - components = {} - if pipeline_class_name in REFINER_PIPELINES: - model_type = infer_model_type(original_config, checkpoint=checkpoint, model_type=model_type) - is_refiner = model_type == "SDXL-Refiner" - components.update( + loading_kwargs = {} + loading_kwargs.update( { - "requires_aesthetics_score": is_refiner, - "force_zeros_for_empty_prompt": False if is_refiner else True, + "pretrained_model_name_or_path": cached_model_config_path, + "subfolder": name, + "local_files_only": local_files_only, } ) - return components + # Schedulers and Tokenizers don't make use of mindspore_dtype + # Skip passing it to those objects + if issubclass(class_obj, nn.Cell): + loading_kwargs.update({"mindspore_dtype": mindspore_dtype}) + + if is_diffusers_model or is_transformers_model: + if not _is_model_weights_in_cached_folder(cached_model_config_path, name): + raise SingleFileComponentError( + f"Failed to load {class_name}. Weights for this component appear to be missing in the checkpoint." + ) + + load_method = getattr(class_obj, "from_pretrained") + loaded_sub_model = load_method(**loading_kwargs) + + return loaded_sub_model + + +def _map_component_types_to_config_dict(component_types): + diffusers_module = importlib.import_module("mindone.diffusers") + config_dict = {} + component_types.pop("self", None) + + if is_transformers_available(): + transformers_version = version.parse(version.parse(transformers.__version__).base_version) + else: + transformers_version = "N/A" + + for component_name, component_value in component_types.items(): + is_diffusers_model = issubclass(component_value[0], diffusers_module.ModelMixin) + is_scheduler_enum = component_value[0].__name__ == "KarrasDiffusionSchedulers" + is_scheduler = issubclass(component_value[0], diffusers_module.SchedulerMixin) + + is_transformers_model = is_transformers_available() and issubclass(component_value[0], PreTrainedModel) + is_transformers_tokenizer = ( + is_transformers_available() + and issubclass(component_value[0], PreTrainedTokenizer) + and transformers_version >= version.parse("4.20.0") + ) + + if is_diffusers_model and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS: + config_dict[component_name] = ["diffusers", component_value[0].__name__] + + elif is_scheduler_enum or is_scheduler: + if is_scheduler_enum: + # Since we cannot fetch a scheduler config from the hub, we default to DDIMScheduler + # if the type hint is a KarrassDiffusionSchedulers enum + config_dict[component_name] = ["diffusers", "DDIMScheduler"] + + elif is_scheduler: + config_dict[component_name] = ["diffusers", component_value[0].__name__] + + elif ( + is_transformers_model or is_transformers_tokenizer + ) and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS: + config_dict[component_name] = ["transformers", component_value[0].__name__] + + else: + config_dict[component_name] = [None, None] + + return config_dict + + +def _infer_pipeline_config_dict(pipeline_class): + parameters = inspect.signature(pipeline_class.__init__).parameters + required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} + component_types = pipeline_class._get_signature_types() + + # Ignore parameters that are not required for the pipeline + component_types = {k: v for k, v in component_types.items() if k in required_parameters} + config_dict = _map_component_types_to_config_dict(component_types) + + return config_dict + + +def _download_diffusers_model_config_from_hub( + pretrained_model_name_or_path, + cache_dir, + revision, + proxies, + force_download=None, + resume_download=None, + local_files_only=None, + token=None, +): + allow_patterns = ["**/*.json", "*.json", "*.txt", "**/*.txt", "**/*.model"] + cached_model_path = snapshot_download( + pretrained_model_name_or_path, + cache_dir=cache_dir, + revision=revision, + proxies=proxies, + force_download=force_download, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + allow_patterns=allow_patterns, + ) + + return cached_model_path class FromSingleFileMixin: @@ -175,9 +286,9 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs): cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory where a downloaded pretrained model configuration is cached if the standard cache is not used. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to resume downloading the model weights and configuration files. If set to `False`, any - incompletely downloaded files are deleted. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1 + of Diffusers. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. @@ -193,23 +304,12 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs): original_config_file (`str`, *optional*): The path to the original config file that was used to train the model. If not provided, the config file will be inferred from the checkpoint file. - model_type (`str`, *optional*): - The type of model to load. If not provided, the model type will be inferred from the checkpoint file. - image_size (`int`, *optional*): - The size of the image output. It's used to configure the `sample_size` parameter of the UNet and VAE model. - load_safety_checker (`bool`, *optional*, defaults to `False`): - Whether to load the safety checker model or not. - By default, the safety checker is not loaded unless a `safety_checker` component is passed to the `kwargs`. - num_in_channels (`int`, *optional*): - Specify the number of input channels for the UNet model. Read more about how to configure UNet model with this parameter - [here](https://huggingface.co/docs/diffusers/training/adapt_a_model#configure-unet2dconditionmodel-parameters). - scaling_factor (`float`, *optional*): - The scaling factor to use for the VAE model. If not provided, it is inferred from the config file first. - If the scaling factor is not found in the config file, the default value 0.18215 is used. - scheduler_type (`str`, *optional*): - The type of scheduler to load. If not provided, the scheduler type will be inferred from the checkpoint file. - prediction_type (`str`, *optional*): - The type of prediction to load. If not provided, the prediction type will be inferred from the checkpoint file. + config (`str`, *optional*): + Can be either: + - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline + hosted on the Hub. + - A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline + component configs in Diffusers format. kwargs (remaining dictionary of keyword arguments, *optional*): Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline class). The overwritten components are passed directly to the pipelines `__init__` method. See example @@ -229,9 +329,21 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs): >>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt >>> pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly") ``` + """ original_config_file = kwargs.pop("original_config_file", None) - resume_download = kwargs.pop("resume_download", False) + config = kwargs.pop("config", None) + original_config = kwargs.pop("original_config", None) + + if original_config_file is not None: + deprecation_message = ( + "`original_config_file` argument is deprecated and will be removed in future versions." + "please use the `original_config` argument instead." + ) + deprecate("original_config_file", "1.0.0", deprecation_message) + original_config = original_config_file + + resume_download = kwargs.pop("resume_download", None) force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) token = kwargs.pop("token", None) @@ -240,68 +352,198 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs): revision = kwargs.pop("revision", None) mindspore_dtype = kwargs.pop("mindspore_dtype", None) - class_name = cls.__name__ + is_legacy_loading = False - original_config, checkpoint = fetch_ldm_config_and_checkpoint( - pretrained_model_link_or_path=pretrained_model_link_or_path, - class_name=class_name, - original_config_file=original_config_file, + # We shouldn't allow configuring individual models components through a Pipeline creation method + # These model kwargs should be deprecated + scaling_factor = kwargs.get("scaling_factor", None) + if scaling_factor is not None: + deprecation_message = ( + "Passing the `scaling_factor` argument to `from_single_file is deprecated " + "and will be ignored in future versions." + ) + deprecate("scaling_factor", "1.0.0", deprecation_message) + + if original_config is not None: + original_config = fetch_original_config(original_config, local_files_only=local_files_only) + + from ..pipelines.pipeline_utils import _get_pipeline_class + + pipeline_class = _get_pipeline_class(cls, config=None) + + checkpoint = load_single_file_checkpoint( + pretrained_model_link_or_path, resume_download=resume_download, force_download=force_download, proxies=proxies, token=token, - revision=revision, - local_files_only=local_files_only, cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, ) - from ..pipelines.pipeline_utils import _get_pipeline_class + if config is None: + config = fetch_diffusers_config(checkpoint) + default_pretrained_model_config_name = config["pretrained_model_name_or_path"] + else: + default_pretrained_model_config_name = config + + if not os.path.isdir(default_pretrained_model_config_name): + # Provided config is a repo_id + if default_pretrained_model_config_name.count("/") > 1: + raise ValueError( + f'The provided config "{config}"' + " is neither a valid local path nor a valid repo id. Please check the parameter." + ) + try: + # Attempt to download the config files for the pipeline + cached_model_config_path = _download_diffusers_model_config_from_hub( + default_pretrained_model_config_name, + cache_dir=cache_dir, + revision=revision, + proxies=proxies, + force_download=force_download, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + ) + config_dict = pipeline_class.load_config(cached_model_config_path) + + except LocalEntryNotFoundError: + # `local_files_only=True` but a local diffusers format model config is not available in the cache + # If `original_config` is not provided, we need override `local_files_only` to False + # to fetch the config files from the hub so that we have a way + # to configure the pipeline components. + + if original_config is None: + logger.warning( + "`local_files_only` is True but no local configs were found for this checkpoint.\n" + "Attempting to download the necessary config files for this pipeline.\n" + ) + cached_model_config_path = _download_diffusers_model_config_from_hub( + default_pretrained_model_config_name, + cache_dir=cache_dir, + revision=revision, + proxies=proxies, + force_download=force_download, + resume_download=resume_download, + local_files_only=False, + token=token, + ) + config_dict = pipeline_class.load_config(cached_model_config_path) + + else: + # For backwards compatibility + # If `original_config` is provided, then we need to assume we are using legacy loading for pipeline components + logger.warning( + "Detected legacy `from_single_file` loading behavior. Attempting to create the pipeline based on inferred components.\n" + "This may lead to errors if the model components are not correctly inferred. \n" + "To avoid this warning, please explicity pass the `config` argument to `from_single_file` with a path to a local diffusers model repo \n" # noqa E501 + "e.g. `from_single_file(, config=) \n" + "or run `from_single_file` with `local_files_only=False` first to update the local cache directory with " + "the necessary config files.\n" + ) + is_legacy_loading = True + cached_model_config_path = None + + config_dict = _infer_pipeline_config_dict(pipeline_class) + config_dict["_class_name"] = pipeline_class.__name__ - pipeline_class = _get_pipeline_class( - cls, - config=None, - cache_dir=cache_dir, - ) + else: + # Provided config is a path to a local directory attempt to load directly. + cached_model_config_path = default_pretrained_model_config_name + config_dict = pipeline_class.load_config(cached_model_config_path) + + # pop out "_ignore_files" as it is only needed for download + config_dict.pop("_ignore_files", None) - expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class) + expected_modules, optional_kwargs = pipeline_class._get_signature_keys(cls) passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs} - model_type = kwargs.pop("model_type", None) - image_size = kwargs.pop("image_size", None) - load_safety_checker = (kwargs.pop("load_safety_checker", False)) or ( - passed_class_obj.get("safety_checker", None) is not None - ) + init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) + init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict} + init_kwargs = {**init_kwargs, **passed_pipe_kwargs} + + from mindone.diffusers import pipelines + + # remove `null` components + def load_module(name, value): + if value[0] is None: + return False + if name in passed_class_obj and passed_class_obj[name] is None: + return False + if name in SINGLE_FILE_OPTIONAL_COMPONENTS: + return False + + return True + + init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)} + + for name, (library_name, class_name) in logging.tqdm( + sorted(init_dict.items()), desc="Loading pipeline components..." + ): + loaded_sub_model = None + is_pipeline_module = hasattr(pipelines, library_name) - init_kwargs = {} - for name in expected_modules: if name in passed_class_obj: - init_kwargs[name] = passed_class_obj[name] + loaded_sub_model = passed_class_obj[name] + else: - components = build_sub_model_components( - init_kwargs, - class_name, - name, - original_config, - checkpoint, - model_type=model_type, - image_size=image_size, - load_safety_checker=load_safety_checker, - local_files_only=local_files_only, - mindspore_dtype=mindspore_dtype, - **kwargs, - ) - if not components: - continue - init_kwargs.update(components) + try: + loaded_sub_model = load_single_file_sub_model( + library_name=library_name, + class_name=class_name, + name=name, + checkpoint=checkpoint, + is_pipeline_module=is_pipeline_module, + cached_model_config_path=cached_model_config_path, + pipelines=pipelines, + mindspore_dtype=mindspore_dtype, + original_config=original_config, + local_files_only=local_files_only, + is_legacy_loading=is_legacy_loading, + **kwargs, + ) + except SingleFileComponentError as e: + raise SingleFileComponentError( + ( + f"{e.message}\n" + f"Please load the component before passing it in as an argument to `from_single_file`.\n" + f"\n" + f"{name} = {class_name}.from_pretrained('...')\n" + f"pipe = {pipeline_class.__name__}.from_single_file(, {name}={name})\n" + f"\n" + ) + ) + + init_kwargs[name] = loaded_sub_model + + missing_modules = set(expected_modules) - set(init_kwargs.keys()) + passed_modules = list(passed_class_obj.keys()) + optional_modules = pipeline_class._optional_components + + if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules): + for module in missing_modules: + init_kwargs[module] = passed_class_obj.get(module, None) + elif len(missing_modules) > 0: + passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs + raise ValueError( + f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed." + ) - additional_components = set_additional_components( - class_name, original_config, checkpoint=checkpoint, model_type=model_type - ) - if additional_components: - init_kwargs.update(additional_components) + # deprecated kwargs + load_safety_checker = kwargs.pop("load_safety_checker", None) + if load_safety_checker is not None: + deprecation_message = ( + "Please pass instances of `StableDiffusionSafetyChecker` and `AutoImageProcessor`" + "using the `safety_checker` and `feature_extractor` arguments in `from_single_file`" + ) + deprecate("load_safety_checker", "1.0.0", deprecation_message) + + safety_checker_components = _legacy_load_safety_checker(local_files_only, mindspore_dtype) + init_kwargs.update(safety_checker_components) - init_kwargs.update(passed_pipe_kwargs) pipe = pipeline_class(**init_kwargs) if mindspore_dtype is not None: diff --git a/mindone/diffusers/loaders/single_file_model.py b/mindone/diffusers/loaders/single_file_model.py new file mode 100644 index 0000000000..d7c793ed8b --- /dev/null +++ b/mindone/diffusers/loaders/single_file_model.py @@ -0,0 +1,283 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import re +from typing import Optional + +from huggingface_hub.utils import validate_hf_hub_args + +from ..models.modeling_utils import _convert_state_dict +from ..utils import deprecate, logging +from .single_file_utils import ( + SingleFileComponentError, + _load_param_into_net, + convert_controlnet_checkpoint, + convert_ldm_unet_checkpoint, + convert_ldm_vae_checkpoint, + convert_sd3_transformer_checkpoint_to_diffusers, + convert_stable_cascade_unet_single_file_to_diffusers, + create_controlnet_diffusers_config_from_ldm, + create_unet_diffusers_config_from_ldm, + create_vae_diffusers_config_from_ldm, + fetch_diffusers_config, + fetch_original_config, + load_single_file_checkpoint, +) + +logger = logging.get_logger(__name__) + + +SINGLE_FILE_LOADABLE_CLASSES = { + "StableCascadeUNet": { + "checkpoint_mapping_fn": convert_stable_cascade_unet_single_file_to_diffusers, + }, + "UNet2DConditionModel": { + "checkpoint_mapping_fn": convert_ldm_unet_checkpoint, + "config_mapping_fn": create_unet_diffusers_config_from_ldm, + "default_subfolder": "unet", + "legacy_kwargs": { + "num_in_channels": "in_channels", # Legacy kwargs supported by `from_single_file` mapped to new args + }, + }, + "AutoencoderKL": { + "checkpoint_mapping_fn": convert_ldm_vae_checkpoint, + "config_mapping_fn": create_vae_diffusers_config_from_ldm, + "default_subfolder": "vae", + }, + "ControlNetModel": { + "checkpoint_mapping_fn": convert_controlnet_checkpoint, + "config_mapping_fn": create_controlnet_diffusers_config_from_ldm, + }, + "SD3Transformer2DModel": { + "checkpoint_mapping_fn": convert_sd3_transformer_checkpoint_to_diffusers, + "default_subfolder": "transformer", + }, +} + + +def _get_mapping_function_kwargs(mapping_fn, **kwargs): + parameters = inspect.signature(mapping_fn).parameters + + mapping_kwargs = {} + for parameter in parameters: + if parameter in kwargs: + mapping_kwargs[parameter] = kwargs[parameter] + + return mapping_kwargs + + +class FromOriginalModelMixin: + """ + Load pretrained weights saved in the `.ckpt` or `.safetensors` format into a model. + """ + + @classmethod + @validate_hf_hub_args + def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs): + r""" + Instantiate a model from pretrained weights saved in the original `.ckpt` or `.safetensors` format. The model + is set in evaluation mode (`model.eval()`) by default. + + Parameters: + pretrained_model_link_or_path_or_dict (`str`, *optional*): + Can be either: + - A link to the `.safetensors` or `.ckpt` file (for example + `"https://huggingface.co//blob/main/.safetensors"`) on the Hub. + - A path to a local *file* containing the weights of the component model. + - A state dict containing the component model weights. + config (`str`, *optional*): + - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline hosted + on the Hub. + - A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline component + configs in Diffusers format. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + original_config (`str`, *optional*): + Dict or path to a yaml file containing the configuration for the model in its original format. + If a dict is provided, it will be used to initialize the model configuration. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the + dtype is automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to resume downloading the model weights and configuration files. If set to `False`, any + incompletely downloaded files are deleted. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to True, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to overwrite load and saveable variables (for example the pipeline components of the + specific pipeline class). The overwritten components are directly passed to the pipelines `__init__` + method. See example below for more information. + + ```py + >>> from mindone.diffusers import StableCascadeUNet + + >>> ckpt_path = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_lite.safetensors" + >>> model = StableCascadeUNet.from_single_file(ckpt_path) + ``` + """ + + class_name = cls.__name__ + if class_name not in SINGLE_FILE_LOADABLE_CLASSES: + raise ValueError( + f"FromOriginalModelMixin is currently only compatible with {', '.join(SINGLE_FILE_LOADABLE_CLASSES.keys())}" + ) + + pretrained_model_link_or_path = kwargs.get("pretrained_model_link_or_path", None) + if pretrained_model_link_or_path is not None: + deprecation_message = ( + "Please use `pretrained_model_link_or_path_or_dict` argument instead for model classes" + ) + deprecate("pretrained_model_link_or_path", "1.0.0", deprecation_message) + pretrained_model_link_or_path_or_dict = pretrained_model_link_or_path + + config = kwargs.pop("config", None) + original_config = kwargs.pop("original_config", None) + + if config is not None and original_config is not None: + raise ValueError( + "`from_single_file` cannot accept both `config` and `original_config` arguments. Please provide only one of these arguments" + ) + + resume_download = kwargs.pop("resume_download", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + token = kwargs.pop("token", None) + cache_dir = kwargs.pop("cache_dir", None) + local_files_only = kwargs.pop("local_files_only", None) + subfolder = kwargs.pop("subfolder", None) + revision = kwargs.pop("revision", None) + mindspore_dtype = kwargs.pop("mindspore_dtype", None) + + if isinstance(pretrained_model_link_or_path_or_dict, dict): + checkpoint = pretrained_model_link_or_path_or_dict + else: + checkpoint = load_single_file_checkpoint( + pretrained_model_link_or_path_or_dict, + resume_download=resume_download, + force_download=force_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + ) + + mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[class_name] + + checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"] + if original_config: + if "config_mapping_fn" in mapping_functions: + config_mapping_fn = mapping_functions["config_mapping_fn"] + else: + config_mapping_fn = None + + if config_mapping_fn is None: + raise ValueError( + ( + f"`original_config` has been provided for {class_name} but no mapping function" + "was found to convert the original config to a Diffusers config in" + "`diffusers.loaders.single_file_utils`" + ) + ) + + if isinstance(original_config, str): + # If original_config is a URL or filepath fetch the original_config dict + original_config = fetch_original_config(original_config, local_files_only=local_files_only) + + config_mapping_kwargs = _get_mapping_function_kwargs(config_mapping_fn, **kwargs) + diffusers_model_config = config_mapping_fn( + original_config=original_config, checkpoint=checkpoint, **config_mapping_kwargs + ) + else: + if config: + if isinstance(config, str): + default_pretrained_model_config_name = config + else: + raise ValueError( + ( + "Invalid `config` argument. Please provide a string representing a repo id" + "or path to a local Diffusers model repo." + ) + ) + + else: + config = fetch_diffusers_config(checkpoint) + default_pretrained_model_config_name = config["pretrained_model_name_or_path"] + + if "default_subfolder" in mapping_functions: + subfolder = mapping_functions["default_subfolder"] + + subfolder = subfolder or config.pop( + "subfolder", None + ) # some configs contain a subfolder key, e.g. StableCascadeUNet + + diffusers_model_config = cls.load_config( + pretrained_model_name_or_path=default_pretrained_model_config_name, + subfolder=subfolder, + local_files_only=local_files_only, + ) + expected_kwargs, optional_kwargs = cls._get_signature_keys(cls) + + # Map legacy kwargs to new kwargs + if "legacy_kwargs" in mapping_functions: + legacy_kwargs = mapping_functions["legacy_kwargs"] + for legacy_key, new_key in legacy_kwargs.items(): + if legacy_key in kwargs: + kwargs[new_key] = kwargs.pop(legacy_key) + + model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs} + diffusers_model_config.update(model_kwargs) + + checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs) + diffusers_format_checkpoint = checkpoint_mapping_fn( + config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs + ) + if not diffusers_format_checkpoint: + raise SingleFileComponentError( + f"Failed to load {class_name}. Weights for this component appear to be missing in the checkpoint." + ) + + model = cls.from_config(diffusers_model_config) + + diffusers_format_checkpoint = _convert_state_dict(model, diffusers_format_checkpoint) + _, unexpected_keys = _load_param_into_net(model, diffusers_format_checkpoint, mindspore_dtype) + + if model._keys_to_ignore_on_load_unexpected is not None: + for pat in model._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" + ) + + model.set_train(False) + + return model diff --git a/mindone/diffusers/loaders/single_file_utils.py b/mindone/diffusers/loaders/single_file_utils.py index 8a9fc8c030..95f5c9dce2 100644 --- a/mindone/diffusers/loaders/single_file_utils.py +++ b/mindone/diffusers/loaders/single_file_utils.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" Conversion script for the Stable Diffusion checkpoints.""" +"""Conversion script for the Stable Diffusion checkpoints.""" import os import re @@ -21,17 +21,13 @@ import requests import yaml -from transformers import CLIPTextConfig, CLIPTokenizer import mindspore as ms -from mindspore import Parameter - -from mindone.transformers import CLIPTextModel, CLIPTextModelWithProjection +from mindspore import Parameter, ops from ..models.modeling_utils import _convert_state_dict, load_state_dict from ..schedulers import ( DDIMScheduler, - DDPMScheduler, DPMSolverMultistepScheduler, EDMDPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler, @@ -40,108 +36,77 @@ LMSDiscreteScheduler, PNDMScheduler, ) -from ..utils import logging +from ..utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, deprecate, is_transformers_available, logging from ..utils.hub_utils import _get_model_file -logger = logging.get_logger(__name__) # pylint: disable=invalid-name +if is_transformers_available(): + from transformers import AutoImageProcessor -CONFIG_URLS = { - "v1": "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml", - "v2": "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml", - "xl": "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml", - "xl_refiner": "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml", - "upscale": "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml", - "controlnet": "https://raw.githubusercontent.com/lllyasviel/ControlNet/main/models/cldm_v15.yaml", -} +logger = logging.get_logger(__name__) # pylint: disable=invalid-name CHECKPOINT_KEY_NAMES = { "v2": "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight", "xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias", "xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias", + "upscale": "model.diffusion_model.input_blocks.10.0.skip_connection.bias", + "controlnet": "control_model.time_embed.0.weight", + "playground-v2-5": "edm_mean", + "inpainting": "model.diffusion_model.input_blocks.0.0.weight", + "clip": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight", + "clip_sdxl": "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight", + "clip_sd3": "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight", + "open_clip": "cond_stage_model.model.token_embedding.weight", + "open_clip_sdxl": "conditioner.embedders.1.model.positional_embedding", + "open_clip_sdxl_refiner": "conditioner.embedders.0.model.text_projection", + "open_clip_sd3": "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight", + "stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight", + "stable_cascade_stage_c": "clip_txt_mapper.weight", + "sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias", } -SCHEDULER_DEFAULT_CONFIG = { - "beta_schedule": "scaled_linear", - "beta_start": 0.00085, - "beta_end": 0.012, - "interpolation_type": "linear", - "num_train_timesteps": 1000, - "prediction_type": "epsilon", - "sample_max_value": 1.0, - "set_alpha_to_one": False, - "skip_prk_steps": True, - "steps_offset": 1, - "timestep_spacing": "leading", +DIFFUSERS_DEFAULT_PIPELINE_PATHS = { + "xl_base": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-base-1.0"}, + "xl_refiner": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-refiner-1.0"}, + "xl_inpaint": {"pretrained_model_name_or_path": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"}, + "playground-v2-5": {"pretrained_model_name_or_path": "playgroundai/playground-v2.5-1024px-aesthetic"}, + "upscale": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-x4-upscaler"}, + "inpainting": {"pretrained_model_name_or_path": "runwayml/stable-diffusion-inpainting"}, + "inpainting_v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-inpainting"}, + "controlnet": {"pretrained_model_name_or_path": "lllyasviel/control_v11p_sd15_canny"}, + "v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-1"}, + "v1": {"pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5"}, + "stable_cascade_stage_b": {"pretrained_model_name_or_path": "stabilityai/stable-cascade", "subfolder": "decoder"}, + "stable_cascade_stage_b_lite": { + "pretrained_model_name_or_path": "stabilityai/stable-cascade", + "subfolder": "decoder_lite", + }, + "stable_cascade_stage_c": { + "pretrained_model_name_or_path": "stabilityai/stable-cascade-prior", + "subfolder": "prior", + }, + "stable_cascade_stage_c_lite": { + "pretrained_model_name_or_path": "stabilityai/stable-cascade-prior", + "subfolder": "prior_lite", + }, + "sd3": { + "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3-medium-diffusers", + }, } -STABLE_CASCADE_DEFAULT_CONFIGS = { - "stage_c": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "prior"}, - "stage_c_lite": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "prior_lite"}, - "stage_b": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "decoder"}, - "stage_b_lite": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "decoder_lite"}, +# Use to configure model sample size when original config is provided +DIFFUSERS_TO_LDM_DEFAULT_IMAGE_SIZE_MAP = { + "xl_base": 1024, + "xl_refiner": 1024, + "xl_inpaint": 1024, + "playground-v2-5": 1024, + "upscale": 512, + "inpainting": 512, + "inpainting_v2": 512, + "controlnet": 512, + "v2": 768, + "v1": 512, } - -def convert_stable_cascade_unet_single_file_to_diffusers(original_state_dict): - is_stage_c = "clip_txt_mapper.weight" in original_state_dict - state_dict = {} - for key in original_state_dict.keys(): - if key.endswith("in_proj_weight"): - weights = original_state_dict[key].chunk(3, 0) - state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = Parameter( - weights[0], name=key.replace("attn.in_proj_weight", "to_q.weight") - ) - state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = Parameter( - weights[1], name=key.replace("attn.in_proj_weight", "to_k.weight") - ) - state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = Parameter( - weights[2], name=key.replace("attn.in_proj_weight", "to_v.weight") - ) - elif key.endswith("in_proj_bias"): - weights = original_state_dict[key].chunk(3, 0) - state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = Parameter( - weights[0], name=key.replace("attn.in_proj_bias", "to_q.bias") - ) - state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = Parameter( - weights[1], name=key.replace("attn.in_proj_bias", "to_k.bias") - ) - state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = Parameter( - weights[2], name=key.replace("attn.in_proj_bias", "to_v.bias") - ) - elif key.endswith("out_proj.weight"): - weights = original_state_dict[key] - state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights - elif key.endswith("out_proj.bias"): - weights = original_state_dict[key] - state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights - elif key.endswith("clip_mapper.weight") and not is_stage_c: - weights = original_state_dict[key] - state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights - elif key.endswith("clip_mapper.bias") and not is_stage_c: - weights = original_state_dict[key] - state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights - else: - state_dict[key] = original_state_dict[key] - - return state_dict - - -def infer_stable_cascade_single_file_config(checkpoint): - is_stage_c = "clip_txt_mapper.weight" in checkpoint - is_stage_b = "down_blocks.1.0.channelwise.0.weight" in checkpoint - config_type = None - if is_stage_c and (checkpoint["clip_txt_mapper.weight"].shape[0] == 1536): - config_type = "stage_c_lite" - elif is_stage_c and (checkpoint["clip_txt_mapper.weight"].shape[0] == 2048): - config_type = "stage_c" - elif is_stage_b and checkpoint["down_blocks.1.0.channelwise.0.weight"].shape[-1] == 576: - config_type = "stage_b_lite" - elif is_stage_b and checkpoint["down_blocks.1.0.channelwise.0.weight"].shape[-1] == 640: - config_type = "stage_b" - - return STABLE_CASCADE_DEFAULT_CONFIGS[config_type] - - DIFFUSERS_TO_LDM_MAPPING = { "unet": { "layers": { @@ -235,14 +200,6 @@ def infer_stable_cascade_single_file_config(checkpoint): }, } -LDM_VAE_KEY = "first_stage_model." -LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215 -PLAYGROUND_VAE_SCALING_FACTOR = 0.5 -LDM_UNET_KEY = "model.diffusion_model." -LDM_CONTROLNET_KEY = "control_model." -LDM_CLIP_PREFIX_TO_REMOVE = ["cond_stage_model.transformer.", "conditioner.embedders.0.transformer."] -LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024 - SD_2_TEXT_ENCODER_KEYS_TO_IGNORE = [ "cond_stage_model.model.transformer.resblocks.23.attn.in_proj_bias", "cond_stage_model.model.transformer.resblocks.23.attn.in_proj_weight", @@ -259,11 +216,54 @@ def infer_stable_cascade_single_file_config(checkpoint): "cond_stage_model.model.text_projection", ] +# To support legacy scheduler_type argument +SCHEDULER_DEFAULT_CONFIG = { + "beta_schedule": "scaled_linear", + "beta_start": 0.00085, + "beta_end": 0.012, + "interpolation_type": "linear", + "num_train_timesteps": 1000, + "prediction_type": "epsilon", + "sample_max_value": 1.0, + "set_alpha_to_one": False, + "skip_prk_steps": True, + "steps_offset": 1, + "timestep_spacing": "leading", +} + +LDM_VAE_KEY = "first_stage_model." +LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215 +PLAYGROUND_VAE_SCALING_FACTOR = 0.5 +LDM_UNET_KEY = "model.diffusion_model." +LDM_CONTROLNET_KEY = "control_model." +LDM_CLIP_PREFIX_TO_REMOVE = [ + "cond_stage_model.transformer.", + "conditioner.embedders.0.transformer.", +] +OPEN_CLIP_PREFIX = "conditioner.embedders.0.model." +LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024 VALID_URL_PREFIXES = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"] +class SingleFileComponentError(Exception): + def __init__(self, message=None): + self.message = message + super().__init__(self.message) + + +def is_valid_url(url): + result = urlparse(url) + if result.scheme and result.netloc: + return True + + return False + + def _extract_repo_id_and_weights_name(pretrained_model_name_or_path): + if not is_valid_url(pretrained_model_name_or_path): + raise ValueError("Invalid `pretrained_model_name_or_path` provided. Please set it to a valid URL.") + pattern = r"([^/]+)/([^/]+)/(?:blob/main/)?(.+)" weights_name = None repo_id = (None,) @@ -271,6 +271,7 @@ def _extract_repo_id_and_weights_name(pretrained_model_name_or_path): pretrained_model_name_or_path = pretrained_model_name_or_path.replace(prefix, "") match = re.match(pattern, pretrained_model_name_or_path) if not match: + logger.warning("Unable to identify the repo_id and weights_name from the provided URL.") return repo_id, weights_name repo_id = f"{match.group(1)}/{match.group(2)}" @@ -279,34 +280,18 @@ def _extract_repo_id_and_weights_name(pretrained_model_name_or_path): return repo_id, weights_name -def fetch_ldm_config_and_checkpoint( - pretrained_model_link_or_path, - class_name, - original_config_file=None, - resume_download=False, - force_download=False, - proxies=None, - token=None, - cache_dir=None, - local_files_only=None, - revision=None, -): - checkpoint = load_single_file_model_checkpoint( - pretrained_model_link_or_path, - resume_download=resume_download, - force_download=force_download, - proxies=proxies, - token=token, - cache_dir=cache_dir, - local_files_only=local_files_only, - revision=revision, - ) - original_config = fetch_original_config(class_name, checkpoint, original_config_file) +def _is_model_weights_in_cached_folder(cached_folder, name): + pretrained_model_name_or_path = os.path.join(cached_folder, name) + weights_exist = False - return original_config, checkpoint + for weights_name in [WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME]: + if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)): + weights_exist = True + return weights_exist -def load_single_file_model_checkpoint( + +def load_single_file_checkpoint( pretrained_model_link_or_path, resume_download=False, force_download=False, @@ -317,10 +302,10 @@ def load_single_file_model_checkpoint( revision=None, ): if os.path.isfile(pretrained_model_link_or_path): - checkpoint = load_state_dict(pretrained_model_link_or_path) + pretrained_model_link_or_path = pretrained_model_link_or_path else: repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path) - checkpoint_path = _get_model_file( + pretrained_model_link_or_path = _get_model_file( repo_id, weights_name=weights_name, force_download=force_download, @@ -331,7 +316,8 @@ def load_single_file_model_checkpoint( token=token, revision=revision, ) - checkpoint = load_state_dict(checkpoint_path) + + checkpoint = load_state_dict(pretrained_model_link_or_path) # some checkpoints contain the model state dict under a "state_dict" key while "state_dict" in checkpoint: @@ -340,120 +326,173 @@ def load_single_file_model_checkpoint( return checkpoint -def infer_original_config_file(class_name, checkpoint): - if CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024: - config_url = CONFIG_URLS["v2"] +def fetch_original_config(original_config_file, local_files_only=False): + if os.path.isfile(original_config_file): + with open(original_config_file, "r") as fp: + original_config_file = fp.read() - elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint: - config_url = CONFIG_URLS["xl"] + elif is_valid_url(original_config_file): + if local_files_only: + raise ValueError( + "`local_files_only` is set to True, but a URL was provided as `original_config_file`. " + "Please provide a valid local file path." + ) - elif CHECKPOINT_KEY_NAMES["xl_refiner"] in checkpoint: - config_url = CONFIG_URLS["xl_refiner"] + original_config_file = BytesIO(requests.get(original_config_file).content) - elif class_name == "StableDiffusionUpscalePipeline": - config_url = CONFIG_URLS["upscale"] + else: + raise ValueError("Invalid `original_config_file` provided. Please set it to a valid file path or URL.") - elif class_name == "ControlNetModel": - config_url = CONFIG_URLS["controlnet"] + original_config = yaml.safe_load(original_config_file) - else: - config_url = CONFIG_URLS["v1"] + return original_config - original_config_file = BytesIO(requests.get(config_url).content) - return original_config_file +def is_clip_model(checkpoint): + if CHECKPOINT_KEY_NAMES["clip"] in checkpoint: + return True + return False -def fetch_original_config(pipeline_class_name, checkpoint, original_config_file=None): - def is_valid_url(url): - result = urlparse(url) - if result.scheme and result.netloc: - return True - return False +def is_clip_sdxl_model(checkpoint): + if CHECKPOINT_KEY_NAMES["clip_sdxl"] in checkpoint: + return True - if original_config_file is None: - original_config_file = infer_original_config_file(pipeline_class_name, checkpoint) + return False - elif os.path.isfile(original_config_file): - with open(original_config_file, "r") as fp: - original_config_file = fp.read() - elif is_valid_url(original_config_file): - original_config_file = BytesIO(requests.get(original_config_file).content) +def is_clip_sd3_model(checkpoint): + if CHECKPOINT_KEY_NAMES["clip_sd3"] in checkpoint: + return True - else: - raise ValueError("Invalid `original_config_file` provided. Please set it to a valid file path or URL.") + return False - original_config = yaml.safe_load(original_config_file) - return original_config +def is_open_clip_model(checkpoint): + if CHECKPOINT_KEY_NAMES["open_clip"] in checkpoint: + return True + return False -def infer_model_type(original_config, checkpoint, model_type=None): - if model_type is not None: - return model_type - has_cond_stage_config = ( - "cond_stage_config" in original_config["model"]["params"] - and original_config["model"]["params"]["cond_stage_config"] is not None - ) - has_network_config = ( - "network_config" in original_config["model"]["params"] - and original_config["model"]["params"]["network_config"] is not None +def is_open_clip_sdxl_model(checkpoint): + if CHECKPOINT_KEY_NAMES["open_clip_sdxl"] in checkpoint: + return True + + return False + + +def is_open_clip_sd3_model(checkpoint): + if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint: + return True + + return False + + +def is_open_clip_sdxl_refiner_model(checkpoint): + if CHECKPOINT_KEY_NAMES["open_clip_sdxl_refiner"] in checkpoint: + return True + + return False + + +def is_clip_model_in_single_file(class_obj, checkpoint): + is_clip_in_checkpoint = any( + [ + is_clip_model(checkpoint), + is_clip_sd3_model(checkpoint), + is_open_clip_model(checkpoint), + is_open_clip_sdxl_model(checkpoint), + is_open_clip_sdxl_refiner_model(checkpoint), + is_open_clip_sd3_model(checkpoint), + ] ) + if ( + class_obj.__name__ == "CLIPTextModel" or class_obj.__name__ == "CLIPTextModelWithProjection" + ) and is_clip_in_checkpoint: + return True - if has_cond_stage_config: - model_type = original_config["model"]["params"]["cond_stage_config"]["target"].split(".")[-1] + return False - elif has_network_config: - context_dim = original_config["model"]["params"]["network_config"]["params"]["context_dim"] - if "edm_mean" in checkpoint and "edm_std" in checkpoint: - model_type = "Playground" - elif context_dim == 2048: - model_type = "SDXL" + +def infer_diffusers_model_type(checkpoint): + if ( + CHECKPOINT_KEY_NAMES["inpainting"] in checkpoint + and checkpoint[CHECKPOINT_KEY_NAMES["inpainting"]].shape[1] == 9 + ): + if CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024: + model_type = "inpainting_v2" else: - model_type = "SDXL-Refiner" - else: - raise ValueError("Unable to infer model type from config") + model_type = "inpainting" - logger.debug(f"No `model_type` given, `model_type` inferred as: {model_type}") + elif CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024: + model_type = "v2" - return model_type + elif CHECKPOINT_KEY_NAMES["playground-v2-5"] in checkpoint: + model_type = "playground-v2-5" + elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint: + model_type = "xl_base" -def get_default_scheduler_config(): - return SCHEDULER_DEFAULT_CONFIG + elif CHECKPOINT_KEY_NAMES["xl_refiner"] in checkpoint: + model_type = "xl_refiner" + elif CHECKPOINT_KEY_NAMES["upscale"] in checkpoint: + model_type = "upscale" -def set_image_size(pipeline_class_name, original_config, checkpoint, image_size=None, model_type=None): - if image_size: - return image_size + elif CHECKPOINT_KEY_NAMES["controlnet"] in checkpoint: + model_type = "controlnet" - global_step = checkpoint["global_step"] if "global_step" in checkpoint else None - model_type = infer_model_type(original_config, checkpoint, model_type) + elif ( + CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"] in checkpoint + and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"]].shape[0] == 1536 + ): + model_type = "stable_cascade_stage_c_lite" - if pipeline_class_name == "StableDiffusionUpscalePipeline": - image_size = original_config["model"]["params"]["unet_config"]["params"]["image_size"] - return image_size + elif ( + CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"] in checkpoint + and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"]].shape[0] == 2048 + ): + model_type = "stable_cascade_stage_c" - elif model_type in ["SDXL", "SDXL-Refiner", "Playground"]: - image_size = 1024 - return image_size + elif ( + CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"] in checkpoint + and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"]].shape[-1] == 576 + ): + model_type = "stable_cascade_stage_b_lite" elif ( - "parameterization" in original_config["model"]["params"] - and original_config["model"]["params"]["parameterization"] == "v" + CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"] in checkpoint + and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"]].shape[-1] == 640 ): - # NOTE: For stable diffusion 2 base one has to pass `image_size==512` - # as it relies on a brittle global step parameter here - image_size = 512 if global_step == 875000 else 768 - return image_size + model_type = "stable_cascade_stage_b" + + elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint: + model_type = "sd3" else: - image_size = 512 + model_type = "v1" + + return model_type + + +def fetch_diffusers_config(checkpoint): + model_type = infer_diffusers_model_type(checkpoint) + model_path = DIFFUSERS_DEFAULT_PIPELINE_PATHS[model_type] + + return model_path + + +def set_image_size(checkpoint, image_size=None): + if image_size: return image_size + model_type = infer_diffusers_model_type(checkpoint) + image_size = DIFFUSERS_TO_LDM_DEFAULT_IMAGE_SIZE_MAP[model_type] + + return image_size + # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear def conv_attn_to_linear(checkpoint): @@ -462,16 +501,27 @@ def conv_attn_to_linear(checkpoint): for key in keys: if ".".join(key.split(".")[-2:]) in attn_keys: if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0, 0] + checkpoint[key] = Parameter(checkpoint[key][:, :, 0, 0], name=key) elif "proj_attn.weight" in key: if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0] + checkpoint[key] = Parameter(checkpoint[key][:, :, 0], name=key) -def create_unet_diffusers_config(original_config, image_size: int): +def create_unet_diffusers_config_from_ldm( + original_config, checkpoint, image_size=None, upcast_attention=None, num_in_channels=None +): """ Creates a config for the diffusers based on the config of the LDM model. """ + if image_size is not None: + deprecation_message = ( + "Configuring UNet2DConditionModel with the `image_size` argument to `from_single_file`" + "is deprecated and will be ignored in future versions." + ) + deprecate("image_size", "1.0.0", deprecation_message) + + image_size = set_image_size(checkpoint, image_size=image_size) + if ( "unet_config" in original_config["model"]["params"] and original_config["model"]["params"]["unet_config"] is not None @@ -480,6 +530,16 @@ def create_unet_diffusers_config(original_config, image_size: int): else: unet_params = original_config["model"]["params"]["network_config"]["params"] + if num_in_channels is not None: + deprecation_message = ( + "Configuring UNet2DConditionModel with the `num_in_channels` argument to `from_single_file`" + "is deprecated and will be ignored in future versions." + ) + deprecate("image_size", "1.0.0", deprecation_message) + in_channels = num_in_channels + else: + in_channels = unet_params["in_channels"] + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]] @@ -542,7 +602,7 @@ def create_unet_diffusers_config(original_config, image_size: int): config = { "sample_size": image_size // vae_scale_factor, - "in_channels": unet_params["in_channels"], + "in_channels": in_channels, "down_block_types": down_block_types, "block_out_channels": block_out_channels, "layers_per_block": unet_params["num_res_blocks"], @@ -556,6 +616,14 @@ def create_unet_diffusers_config(original_config, image_size: int): "transformer_layers_per_block": transformer_layers_per_block, } + if upcast_attention is not None: + deprecation_message = ( + "Configuring UNet2DConditionModel with the `upcast_attention` argument to `from_single_file`" + "is deprecated and will be ignored in future versions." + ) + deprecate("image_size", "1.0.0", deprecation_message) + config["upcast_attention"] = upcast_attention + if "disable_self_attentions" in unet_params: config["only_cross_attention"] = unet_params["disable_self_attentions"] @@ -568,9 +636,18 @@ def create_unet_diffusers_config(original_config, image_size: int): return config -def create_controlnet_diffusers_config(original_config, image_size: int): +def create_controlnet_diffusers_config_from_ldm(original_config, checkpoint, image_size=None, **kwargs): + if image_size is not None: + deprecation_message = ( + "Configuring ControlNetModel with the `image_size` argument" + "is deprecated and will be ignored in future versions." + ) + deprecate("image_size", "1.0.0", deprecation_message) + + image_size = set_image_size(checkpoint, image_size=image_size) + unet_params = original_config["model"]["params"]["control_stage_config"]["params"] - diffusers_unet_config = create_unet_diffusers_config(original_config, image_size=image_size) + diffusers_unet_config = create_unet_diffusers_config_from_ldm(original_config, image_size=image_size) controlnet_config = { "conditioning_channels": unet_params["hint_channels"], @@ -591,15 +668,33 @@ def create_controlnet_diffusers_config(original_config, image_size: int): return controlnet_config -def create_vae_diffusers_config(original_config, image_size, scaling_factor=None, latents_mean=None, latents_std=None): +def create_vae_diffusers_config_from_ldm(original_config, checkpoint, image_size=None, scaling_factor=None): """ Creates a config for the diffusers based on the config of the LDM model. """ + if image_size is not None: + deprecation_message = ( + "Configuring AutoencoderKL with the `image_size` argument" + "is deprecated and will be ignored in future versions." + ) + deprecate("image_size", "1.0.0", deprecation_message) + + image_size = set_image_size(checkpoint, image_size=image_size) + + if "edm_mean" in checkpoint and "edm_std" in checkpoint: + latents_mean = checkpoint["edm_mean"] + latents_std = checkpoint["edm_std"] + else: + latents_mean = None + latents_std = None + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] if (scaling_factor is None) and (latents_mean is not None) and (latents_std is not None): scaling_factor = PLAYGROUND_VAE_SCALING_FACTOR + elif (scaling_factor is None) and ("scale_factor" in original_config["model"]["params"]): scaling_factor = original_config["model"]["params"]["scale_factor"] + elif scaling_factor is None: scaling_factor = LDM_VAE_DEFAULT_SCALING_FACTOR @@ -636,7 +731,7 @@ def update_unet_resnet_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, ma ) if mapping: diffusers_key = diffusers_key.replace(mapping["old"], mapping["new"]) - new_checkpoint[diffusers_key] = checkpoint.pop(ldm_key) + new_checkpoint[diffusers_key] = checkpoint.get(ldm_key) def update_unet_attention_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, mapping): @@ -645,7 +740,119 @@ def update_unet_attention_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, new_checkpoint[diffusers_key] = checkpoint.pop(ldm_key) -def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False): +def update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping): + for ldm_key in keys: + diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]).replace("nin_shortcut", "conv_shortcut") + new_checkpoint[diffusers_key] = checkpoint.get(ldm_key) + + +def update_vae_attentions_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping): + for ldm_key in keys: + diffusers_key = ( + ldm_key.replace(mapping["old"], mapping["new"]) + .replace("norm.weight", "group_norm.weight") + .replace("norm.bias", "group_norm.bias") + .replace("q.weight", "to_q.weight") + .replace("q.bias", "to_q.bias") + .replace("k.weight", "to_k.weight") + .replace("k.bias", "to_k.bias") + .replace("v.weight", "to_v.weight") + .replace("v.bias", "to_v.bias") + .replace("proj_out.weight", "to_out.0.weight") + .replace("proj_out.bias", "to_out.0.bias") + ) + new_checkpoint[diffusers_key] = checkpoint.get(ldm_key) + + # proj_attn.weight has to be converted from conv 1D to linear + shape = new_checkpoint[diffusers_key].shape + + if len(shape) == 3: + new_checkpoint[diffusers_key] = Parameter(new_checkpoint[diffusers_key][:, :, 0], name=diffusers_key) + elif len(shape) == 4: + new_checkpoint[diffusers_key] = Parameter(new_checkpoint[diffusers_key][:, :, 0, 0], name=diffusers_key) + + +def convert_stable_cascade_unet_single_file_to_diffusers(checkpoint, **kwargs): + is_stage_c = "clip_txt_mapper.weight" in checkpoint + + if is_stage_c: + state_dict = {} + for key in checkpoint.keys(): + if key.endswith("in_proj_weight"): + weights = checkpoint[key].chunk(3, 0) + state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = Parameter( + weights[0], name=key.replace("attn.in_proj_weight", "to_q.weight") + ) + state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = Parameter( + weights[1], name=key.replace("attn.in_proj_weight", "to_k.weight") + ) + state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = Parameter( + weights[2], name=key.replace("attn.in_proj_weight", "to_v.weight") + ) + elif key.endswith("in_proj_bias"): + weights = checkpoint[key].chunk(3, 0) + state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = Parameter( + weights[0], name=key.replace("attn.in_proj_bias", "to_q.bias") + ) + state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = Parameter( + weights[1], name=key.replace("attn.in_proj_bias", "to_k.bias") + ) + state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = Parameter( + weights[2], name=key.replace("attn.in_proj_bias", "to_v.bias") + ) + elif key.endswith("out_proj.weight"): + weights = checkpoint[key] + state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights + elif key.endswith("out_proj.bias"): + weights = checkpoint[key] + state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights + else: + state_dict[key] = checkpoint[key] + else: + state_dict = {} + for key in checkpoint.keys(): + if key.endswith("in_proj_weight"): + weights = checkpoint[key].chunk(3, 0) + state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = Parameter( + weights[0], name=key.replace("attn.in_proj_weight", "to_q.weight") + ) + state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = Parameter( + weights[1], name=key.replace("attn.in_proj_weight", "to_k.weight") + ) + state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = Parameter( + weights[2], name=key.replace("attn.in_proj_weight", "to_v.weight") + ) + elif key.endswith("in_proj_bias"): + weights = checkpoint[key].chunk(3, 0) + state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = Parameter( + weights[0], name=key.replace("attn.in_proj_bias", "to_q.bias") + ) + state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = Parameter( + weights[1], name=key.replace("attn.in_proj_bias", "to_k.bias") + ) + state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = Parameter( + weights[2], name=key.replace("attn.in_proj_bias", "to_v.bias") + ) + elif key.endswith("out_proj.weight"): + weights = checkpoint[key] + state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights + elif key.endswith("out_proj.bias"): + weights = checkpoint[key] + state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights + # rename clip_mapper to clip_txt_pooled_mapper + elif key.endswith("clip_mapper.weight"): + weights = checkpoint[key] + state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights + elif key.endswith("clip_mapper.bias"): + weights = checkpoint[key] + state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights + else: + state_dict[key] = checkpoint[key] + + return state_dict + + +def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False, **kwargs): """ Takes a state dict and a config, and returns a converted checkpoint. """ @@ -664,7 +871,7 @@ def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False): for key in keys: if key.startswith("model.diffusion_model"): flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.get(flat_ema_key) else: if sum(k.startswith("model_ema") for k in keys) > 100: logger.warning( @@ -673,7 +880,7 @@ def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False): ) for key in keys: if key.startswith(unet_key): - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.get(key) new_checkpoint = {} ldm_unet_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["layers"] @@ -734,10 +941,10 @@ def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False): ) if f"input_blocks.{i}.0.op.weight" in unet_state_dict: - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.get( f"input_blocks.{i}.0.op.weight" ) - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.get( f"input_blocks.{i}.0.op.bias" ) @@ -751,19 +958,22 @@ def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False): ) # Mid blocks - resnet_0 = middle_blocks[0] - attentions = middle_blocks[1] - resnet_1 = middle_blocks[2] - - update_unet_resnet_ldm_to_diffusers( - resnet_0, new_checkpoint, unet_state_dict, mapping={"old": "middle_block.0", "new": "mid_block.resnets.0"} - ) - update_unet_resnet_ldm_to_diffusers( - resnet_1, new_checkpoint, unet_state_dict, mapping={"old": "middle_block.2", "new": "mid_block.resnets.1"} - ) - update_unet_attention_ldm_to_diffusers( - attentions, new_checkpoint, unet_state_dict, mapping={"old": "middle_block.1", "new": "mid_block.attentions.0"} - ) + for key in middle_blocks.keys(): + diffusers_key = max(key - 1, 0) + if key % 2 == 0: + update_unet_resnet_ldm_to_diffusers( + middle_blocks[key], + new_checkpoint, + unet_state_dict, + mapping={"old": f"middle_block.{key}", "new": f"mid_block.resnets.{diffusers_key}"}, + ) + else: + update_unet_attention_ldm_to_diffusers( + middle_blocks[key], + new_checkpoint, + unet_state_dict, + mapping={"old": f"middle_block.{key}", "new": f"mid_block.attentions.{diffusers_key}"}, + ) # Up Blocks for i in range(num_output_blocks): @@ -812,6 +1022,7 @@ def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False): def convert_controlnet_checkpoint( checkpoint, config, + **kwargs, ): # Some controlnet ckpt files are distributed independently from the rest of the # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/ @@ -824,7 +1035,7 @@ def convert_controlnet_checkpoint( controlnet_key = LDM_CONTROLNET_KEY for key in keys: if key.startswith(controlnet_key): - controlnet_state_dict[key.replace(controlnet_key, "")] = checkpoint.pop(key) + controlnet_state_dict[key.replace(controlnet_key, "")] = checkpoint.get(key) new_checkpoint = {} ldm_controlnet_keys = DIFFUSERS_TO_LDM_MAPPING["controlnet"]["layers"] @@ -858,10 +1069,10 @@ def convert_controlnet_checkpoint( ) if f"input_blocks.{i}.0.op.weight" in controlnet_state_dict: - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = controlnet_state_dict.pop( + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = controlnet_state_dict.get( f"input_blocks.{i}.0.op.weight" ) - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = controlnet_state_dict.pop( + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = controlnet_state_dict.get( f"input_blocks.{i}.0.op.bias" ) @@ -876,8 +1087,8 @@ def convert_controlnet_checkpoint( # controlnet down blocks for i in range(num_input_blocks): - new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = controlnet_state_dict.pop(f"zero_convs.{i}.0.weight") - new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = controlnet_state_dict.pop(f"zero_convs.{i}.0.bias") + new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = controlnet_state_dict.get(f"zero_convs.{i}.0.weight") + new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = controlnet_state_dict.get(f"zero_convs.{i}.0.bias") # Retrieves the keys for the middle blocks only num_middle_blocks = len( @@ -887,33 +1098,28 @@ def convert_controlnet_checkpoint( layer_id: [key for key in controlnet_state_dict if f"middle_block.{layer_id}" in key] for layer_id in range(num_middle_blocks) } - if middle_blocks: - resnet_0 = middle_blocks[0] - attentions = middle_blocks[1] - resnet_1 = middle_blocks[2] - update_unet_resnet_ldm_to_diffusers( - resnet_0, - new_checkpoint, - controlnet_state_dict, - mapping={"old": "middle_block.0", "new": "mid_block.resnets.0"}, - ) - update_unet_resnet_ldm_to_diffusers( - resnet_1, - new_checkpoint, - controlnet_state_dict, - mapping={"old": "middle_block.2", "new": "mid_block.resnets.1"}, - ) - update_unet_attention_ldm_to_diffusers( - attentions, - new_checkpoint, - controlnet_state_dict, - mapping={"old": "middle_block.1", "new": "mid_block.attentions.0"}, - ) + # Mid blocks + for key in middle_blocks.keys(): + diffusers_key = max(key - 1, 0) + if key % 2 == 0: + update_unet_resnet_ldm_to_diffusers( + middle_blocks[key], + new_checkpoint, + controlnet_state_dict, + mapping={"old": f"middle_block.{key}", "new": f"mid_block.resnets.{diffusers_key}"}, + ) + else: + update_unet_attention_ldm_to_diffusers( + middle_blocks[key], + new_checkpoint, + controlnet_state_dict, + mapping={"old": f"middle_block.{key}", "new": f"mid_block.attentions.{diffusers_key}"}, + ) # mid block - new_checkpoint["controlnet_mid_block.weight"] = controlnet_state_dict.pop("middle_block_out.0.weight") - new_checkpoint["controlnet_mid_block.bias"] = controlnet_state_dict.pop("middle_block_out.0.bias") + new_checkpoint["controlnet_mid_block.weight"] = controlnet_state_dict.get("middle_block_out.0.weight") + new_checkpoint["controlnet_mid_block.bias"] = controlnet_state_dict.get("middle_block_out.0.bias") # controlnet cond embedding blocks cond_embedding_blocks = { @@ -927,69 +1133,16 @@ def convert_controlnet_checkpoint( diffusers_idx = idx - 1 cond_block_id = 2 * idx - new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.weight"] = controlnet_state_dict.pop( + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.weight"] = controlnet_state_dict.get( f"input_hint_block.{cond_block_id}.weight" ) - new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = controlnet_state_dict.pop( + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = controlnet_state_dict.get( f"input_hint_block.{cond_block_id}.bias" ) return new_checkpoint -def create_diffusers_controlnet_model_from_ldm( - pipeline_class_name, original_config, checkpoint, upcast_attention=False, image_size=None, mindspore_dtype=None -): - # import here to avoid circular imports - from ..models import ControlNetModel - - image_size = set_image_size(pipeline_class_name, original_config, checkpoint, image_size=image_size) - - diffusers_config = create_controlnet_diffusers_config(original_config, image_size=image_size) - diffusers_config["upcast_attention"] = upcast_attention - - diffusers_format_controlnet_checkpoint = convert_controlnet_checkpoint(checkpoint, diffusers_config) - controlnet = ControlNetModel(**diffusers_config) - _load_param_into_net(controlnet, diffusers_format_controlnet_checkpoint) - - if mindspore_dtype is not None: - controlnet = controlnet.to(mindspore_dtype) - - return {"controlnet": controlnet} - - -def update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping): - for ldm_key in keys: - diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]).replace("nin_shortcut", "conv_shortcut") - new_checkpoint[diffusers_key] = checkpoint.pop(ldm_key) - - -def update_vae_attentions_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping): - for ldm_key in keys: - diffusers_key = ( - ldm_key.replace(mapping["old"], mapping["new"]) - .replace("norm.weight", "group_norm.weight") - .replace("norm.bias", "group_norm.bias") - .replace("q.weight", "to_q.weight") - .replace("q.bias", "to_q.bias") - .replace("k.weight", "to_k.weight") - .replace("k.bias", "to_k.bias") - .replace("v.weight", "to_v.weight") - .replace("v.bias", "to_v.bias") - .replace("proj_out.weight", "to_out.0.weight") - .replace("proj_out.bias", "to_out.0.bias") - ) - new_checkpoint[diffusers_key] = checkpoint.pop(ldm_key) - - # proj_attn.weight has to be converted from conv 1D to linear - shape = new_checkpoint[diffusers_key].shape - - if len(shape) == 3: - new_checkpoint[diffusers_key] = Parameter(new_checkpoint[diffusers_key][:, :, 0], name=diffusers_key) - elif len(shape) == 4: - new_checkpoint[diffusers_key] = Parameter(new_checkpoint[diffusers_key][:, :, 0, 0], name=diffusers_key) - - def convert_ldm_vae_checkpoint(checkpoint, config): # extract state dict for VAE # remove the LDM_VAE_KEY prefix from the ldm checkpoint keys so that it is easier to map them to diffusers keys @@ -1022,10 +1175,10 @@ def convert_ldm_vae_checkpoint(checkpoint, config): mapping={"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}, ) if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: - new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.get( f"encoder.down.{i}.downsample.conv.weight" ) - new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.get( f"encoder.down.{i}.downsample.conv.bias" ) @@ -1090,65 +1243,38 @@ def convert_ldm_vae_checkpoint(checkpoint, config): return new_checkpoint -def create_text_encoder_from_ldm_clip_checkpoint(config_name, checkpoint, local_files_only=False, mindspore_dtype=None): - try: - config = CLIPTextConfig.from_pretrained(config_name, local_files_only=local_files_only) - except Exception: - raise ValueError( - f"With local_files_only set to {local_files_only}, you must first locally save the configuration \ - in the following path:'openai/clip-vit-large-patch14'." - ) - - text_model = CLIPTextModel(config) - +def convert_ldm_clip_checkpoint(checkpoint, remove_prefix=None): keys = list(checkpoint.keys()) text_model_dict = {} - remove_prefixes = LDM_CLIP_PREFIX_TO_REMOVE + remove_prefixes = [] + remove_prefixes.extend(LDM_CLIP_PREFIX_TO_REMOVE) + if remove_prefix: + remove_prefixes.append(remove_prefix) for key in keys: for prefix in remove_prefixes: if key.startswith(prefix): diffusers_key = key.replace(prefix, "") - text_model_dict[diffusers_key] = checkpoint[key] - - if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)): - text_model_dict.pop("text_model.embeddings.position_ids", None) - text_model_dict_ms = _convert_state_dict(text_model, text_model_dict) - _load_param_into_net(text_model, text_model_dict_ms) + text_model_dict[diffusers_key] = checkpoint.get(key) - if mindspore_dtype is not None: - text_model = text_model.to(mindspore_dtype) + return text_model_dict - return text_model - -def create_text_encoder_from_open_clip_checkpoint( - config_name, +def convert_open_clip_checkpoint( + text_model, checkpoint, prefix="cond_stage_model.model.", - has_projection=False, - local_files_only=False, - mindspore_dtype=None, - **config_kwargs, ): - try: - config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs, local_files_only=local_files_only) - except Exception: - raise ValueError( - f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: '{config_name}'." - ) - - text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config) - text_model_dict = {} text_proj_key = prefix + "text_projection" - text_proj_dim = ( - int(checkpoint[text_proj_key].shape[0]) if text_proj_key in checkpoint else LDM_OPEN_CLIP_TEXT_PROJECTION_DIM - ) - text_model_dict["text_model.embeddings.position_ids"] = Parameter( - text_model.text_model.embeddings.position_ids, name="text_model.embeddings.position_ids" - ) + + if text_proj_key in checkpoint: + text_proj_dim = int(checkpoint[text_proj_key].shape[0]) + elif hasattr(text_model.config, "projection_dim"): + text_proj_dim = text_model.config.projection_dim + else: + text_proj_dim = LDM_OPEN_CLIP_TEXT_PROJECTION_DIM keys = list(checkpoint.keys()) keys_to_ignore = SD_2_TEXT_ENCODER_KEYS_TO_IGNORE @@ -1180,285 +1306,179 @@ def create_text_encoder_from_open_clip_checkpoint( ) if key.endswith(".in_proj_weight"): - weight_value = checkpoint[key] + weight_value = checkpoint.get(key) text_model_dict[diffusers_key + ".q_proj.weight"] = Parameter( - weight_value[:text_proj_dim, :], name=diffusers_key + ".q_proj.weight" + weight_value[:text_proj_dim, :].copy(), name=diffusers_key + ".q_proj.weight" ) text_model_dict[diffusers_key + ".k_proj.weight"] = Parameter( - weight_value[text_proj_dim : text_proj_dim * 2, :], name=diffusers_key + ".k_proj.weight" + weight_value[text_proj_dim : text_proj_dim * 2, :].copy(), + name=diffusers_key + ".k_proj.weight", ) text_model_dict[diffusers_key + ".v_proj.weight"] = Parameter( - weight_value[text_proj_dim * 2 :, :], name=diffusers_key + ".v_proj.weight" + weight_value[text_proj_dim * 2 :, :].copy(), name=diffusers_key + ".v_proj.weight" ) elif key.endswith(".in_proj_bias"): - weight_value = checkpoint[key] + weight_value = checkpoint.get(key) text_model_dict[diffusers_key + ".q_proj.bias"] = Parameter( - weight_value[:text_proj_dim], name=diffusers_key + ".q_proj.bias" + weight_value[:text_proj_dim].copy(), name=diffusers_key + ".q_proj.bias" ) text_model_dict[diffusers_key + ".k_proj.bias"] = Parameter( - weight_value[text_proj_dim : text_proj_dim * 2], name=diffusers_key + ".k_proj.bias" + weight_value[text_proj_dim : text_proj_dim * 2].copy(), name=diffusers_key + ".k_proj.bias" ) text_model_dict[diffusers_key + ".v_proj.bias"] = Parameter( - weight_value[text_proj_dim * 2 :], name=diffusers_key + ".v_proj.bias" + weight_value[text_proj_dim * 2 :].copy(), name=diffusers_key + ".v_proj.bias" ) else: - text_model_dict[diffusers_key] = checkpoint[key] - - if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)): - text_model_dict.pop("text_model.embeddings.position_ids", None) - text_model_dict_ms = _convert_state_dict(text_model, text_model_dict) - _load_param_into_net(text_model, text_model_dict_ms) - - if mindspore_dtype is not None: - text_model = text_model.to(mindspore_dtype) + text_model_dict[diffusers_key] = checkpoint.get(key) - return text_model + return text_model_dict -def create_diffusers_unet_model_from_ldm( - pipeline_class_name, - original_config, +def create_diffusers_clip_model_from_ldm( + cls, checkpoint, - num_in_channels=None, - upcast_attention=None, - extract_ema=False, - image_size=None, + subfolder="", + config=None, mindspore_dtype=None, - model_type=None, + local_files_only=None, + is_legacy_loading=False, ): - from ..models import UNet2DConditionModel - - if num_in_channels is None: - if pipeline_class_name in [ - "StableDiffusionInpaintPipeline", - "StableDiffusionControlNetInpaintPipeline", - "StableDiffusionXLInpaintPipeline", - "StableDiffusionXLControlNetInpaintPipeline", - ]: - num_in_channels = 9 - - elif pipeline_class_name == "StableDiffusionUpscalePipeline": - num_in_channels = 7 - - else: - num_in_channels = 4 + if config: + config = {"pretrained_model_name_or_path": config} + else: + config = fetch_diffusers_config(checkpoint) - image_size = set_image_size( - pipeline_class_name, original_config, checkpoint, image_size=image_size, model_type=model_type - ) - unet_config = create_unet_diffusers_config(original_config, image_size=image_size) - unet_config["in_channels"] = num_in_channels - if upcast_attention is not None: - unet_config["upcast_attention"] = upcast_attention + # For backwards compatibility + # Older versions of `from_single_file` expected CLIP configs to be placed in their original transformers model repo + # in the cache_dir, rather than in a subfolder of the Diffusers model + if is_legacy_loading: + logger.warning( + ( + "Detected legacy CLIP loading behavior. Please run `from_single_file` with `local_files_only=False once to update " + "the local cache directory with the necessary CLIP model config files. " + "Attempting to load CLIP model from legacy cache directory." + ) + ) - diffusers_format_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config, extract_ema=extract_ema) + if is_clip_model(checkpoint) or is_clip_sdxl_model(checkpoint): + clip_config = "openai/clip-vit-large-patch14" + config["pretrained_model_name_or_path"] = clip_config + subfolder = "" - unet = UNet2DConditionModel(**unet_config) - _load_param_into_net(unet, diffusers_format_unet_checkpoint) + elif is_open_clip_model(checkpoint): + clip_config = "stabilityai/stable-diffusion-2" + config["pretrained_model_name_or_path"] = clip_config + subfolder = "text_encoder" - if mindspore_dtype is not None: - unet = unet.to(mindspore_dtype) + else: + clip_config = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + config["pretrained_model_name_or_path"] = clip_config + subfolder = "" - return {"unet": unet} + model_config = cls.config_class.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only) + model = cls(model_config) + position_embedding_dim = model.text_model.embeddings.position_embedding.embedding_table.shape[-1] -def create_diffusers_vae_model_from_ldm( - pipeline_class_name, - original_config, - checkpoint, - image_size=None, - scaling_factor=None, - mindspore_dtype=None, - model_type=None, -): - # import here to avoid circular imports - from ..models import AutoencoderKL + if is_clip_model(checkpoint): + diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint) - image_size = set_image_size( - pipeline_class_name, original_config, checkpoint, image_size=image_size, model_type=model_type - ) - model_type = infer_model_type(original_config, checkpoint, model_type) + elif ( + is_clip_sdxl_model(checkpoint) + and checkpoint[CHECKPOINT_KEY_NAMES["clip_sdxl"]].shape[-1] == position_embedding_dim + ): + diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint) - if model_type == "Playground": - edm_mean = ( - checkpoint["edm_mean"].to(dtype=mindspore_dtype).tolist() - if mindspore_dtype - else checkpoint["edm_mean"].tolist() - ) - edm_std = ( - checkpoint["edm_std"].to(dtype=mindspore_dtype).tolist() - if mindspore_dtype - else checkpoint["edm_std"].tolist() + elif ( + is_clip_sd3_model(checkpoint) + and checkpoint[CHECKPOINT_KEY_NAMES["clip_sd3"]].shape[-1] == position_embedding_dim + ): + diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_l.transformer.") + diffusers_format_checkpoint["text_projection.weight"] = Parameter( + ops.eye(position_embedding_dim), name="text_projection.weight" ) - else: - edm_mean = None - edm_std = None - - vae_config = create_vae_diffusers_config( - original_config, - image_size=image_size, - scaling_factor=scaling_factor, - latents_mean=edm_mean, - latents_std=edm_std, - ) - diffusers_format_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) - vae = AutoencoderKL(**vae_config) - _load_param_into_net(vae, diffusers_format_vae_checkpoint) - - if mindspore_dtype is not None: - vae = vae.to(mindspore_dtype) - - return {"vae": vae} + elif is_open_clip_model(checkpoint): + prefix = "cond_stage_model.model." + diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix) -def create_text_encoders_and_tokenizers_from_ldm( - original_config, - checkpoint, - model_type=None, - local_files_only=False, - mindspore_dtype=None, -): - model_type = infer_model_type(original_config, checkpoint=checkpoint, model_type=model_type) - - if model_type == "FrozenOpenCLIPEmbedder": - config_name = "stabilityai/stable-diffusion-2" - config_kwargs = {"subfolder": "text_encoder"} - - try: - text_encoder = create_text_encoder_from_open_clip_checkpoint( - config_name, - checkpoint, - local_files_only=local_files_only, - mindspore_dtype=mindspore_dtype, - **config_kwargs, - ) - tokenizer = CLIPTokenizer.from_pretrained( - config_name, subfolder="tokenizer", local_files_only=local_files_only - ) - except Exception: - raise ValueError( - f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder in the following path: '{config_name}'." - ) - else: - return {"text_encoder": text_encoder, "tokenizer": tokenizer} - - elif model_type == "FrozenCLIPEmbedder": - try: - config_name = "openai/clip-vit-large-patch14" - text_encoder = create_text_encoder_from_ldm_clip_checkpoint( - config_name, - checkpoint, - local_files_only=local_files_only, - mindspore_dtype=mindspore_dtype, - ) - tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only) - - except Exception: - raise ValueError( - f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: '{config_name}'." - ) - else: - return {"text_encoder": text_encoder, "tokenizer": tokenizer} + elif ( + is_open_clip_sdxl_model(checkpoint) + and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sdxl"]].shape[-1] == position_embedding_dim + ): + prefix = "conditioner.embedders.1.model." + diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix) - elif model_type == "SDXL-Refiner": - config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" - config_kwargs = {"projection_dim": 1280} + elif is_open_clip_sdxl_refiner_model(checkpoint): prefix = "conditioner.embedders.0.model." + diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix) - try: - tokenizer_2 = CLIPTokenizer.from_pretrained(config_name, pad_token="!", local_files_only=local_files_only) - text_encoder_2 = create_text_encoder_from_open_clip_checkpoint( - config_name, - checkpoint, - prefix=prefix, - has_projection=True, - local_files_only=local_files_only, - mindspore_dtype=mindspore_dtype, - **config_kwargs, - ) - except Exception: - raise ValueError( - f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder_2 and tokenizer_2 \ - in the following path: {config_name} with `pad_token` set to '!'." - ) + elif ( + is_open_clip_sd3_model(checkpoint) + and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sd3"]].shape[-1] == position_embedding_dim + ): + diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_g.transformer.") - else: - return { - "text_encoder": None, - "tokenizer": None, - "tokenizer_2": tokenizer_2, - "text_encoder_2": text_encoder_2, - } + else: + raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.") - elif model_type in ["SDXL", "Playground"]: - try: - config_name = "openai/clip-vit-large-patch14" - tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only) - text_encoder = create_text_encoder_from_ldm_clip_checkpoint( - config_name, checkpoint, local_files_only=local_files_only, mindspore_dtype=mindspore_dtype - ) + diffusers_format_checkpoint = _convert_state_dict(model, diffusers_format_checkpoint) + _, unexpected_keys = _load_param_into_net(model, diffusers_format_checkpoint, mindspore_dtype) - except Exception: - raise ValueError( - f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder and tokenizer \ - in the following path: 'openai/clip-vit-large-patch14'." - ) + if model._keys_to_ignore_on_load_unexpected is not None: + for pat in model._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] - try: - config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" - config_kwargs = {"projection_dim": 1280} - prefix = "conditioner.embedders.1.model." - tokenizer_2 = CLIPTokenizer.from_pretrained(config_name, pad_token="!", local_files_only=local_files_only) - text_encoder_2 = create_text_encoder_from_open_clip_checkpoint( - config_name, - checkpoint, - prefix=prefix, - has_projection=True, - local_files_only=local_files_only, - mindspore_dtype=mindspore_dtype, - **config_kwargs, - ) - except Exception: - raise ValueError( - f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder_2 and tokenizer_2 \ - in the following path: {config_name} with `pad_token` set to '!'." - ) + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" + ) - return { - "tokenizer": tokenizer, - "text_encoder": text_encoder, - "tokenizer_2": tokenizer_2, - "text_encoder_2": text_encoder_2, - } + model.set_train(False) - return + return model -def create_scheduler_from_ldm( - pipeline_class_name, - original_config, +def _legacy_load_scheduler( + cls, checkpoint, - prediction_type=None, - scheduler_type="ddim", - model_type=None, + component_name, + original_config=None, + **kwargs, ): - scheduler_config = get_default_scheduler_config() - model_type = infer_model_type(original_config, checkpoint=checkpoint, model_type=model_type) + scheduler_type = kwargs.get("scheduler_type", None) + prediction_type = kwargs.get("prediction_type", None) + + if scheduler_type is not None: + deprecation_message = ( + "Please pass an instance of a Scheduler object directly to the `scheduler` argument in `from_single_file`." + ) + deprecate("scheduler_type", "1.0.0", deprecation_message) + + if prediction_type is not None: + deprecation_message = ( + "Please configure an instance of a Scheduler with the appropriate `prediction_type` " + "and pass the object directly to the `scheduler` argument in `from_single_file`." + ) + deprecate("prediction_type", "1.0.0", deprecation_message) + + scheduler_config = SCHEDULER_DEFAULT_CONFIG + model_type = infer_diffusers_model_type(checkpoint=checkpoint) global_step = checkpoint["global_step"] if "global_step" in checkpoint else None - num_train_timesteps = getattr(original_config["model"]["params"], "timesteps", None) or 1000 + if original_config: + num_train_timesteps = getattr(original_config["model"]["params"], "timesteps", 1000) + else: + num_train_timesteps = 1000 + scheduler_config["num_train_timesteps"] = num_train_timesteps - if ( - "parameterization" in original_config["model"]["params"] - and original_config["model"]["params"]["parameterization"] == "v" - ): + if model_type == "v2": if prediction_type is None: - # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"` - # as it relies on a brittle global step parameter here + # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"` # as it relies on a brittle global step parameter here prediction_type = "epsilon" if global_step == 875000 else "v_prediction" else: @@ -1466,20 +1486,44 @@ def create_scheduler_from_ldm( scheduler_config["prediction_type"] = prediction_type - if model_type in ["SDXL", "SDXL-Refiner"]: + if model_type in ["xl_base", "xl_refiner"]: scheduler_type = "euler" - elif model_type == "Playground": + elif model_type == "playground": scheduler_type = "edm_dpm_solver_multistep" else: - beta_start = original_config["model"]["params"].get("linear_start", 0.02) - beta_end = original_config["model"]["params"].get("linear_end", 0.085) + if original_config: + beta_start = original_config["model"]["params"].get("linear_start") + beta_end = original_config["model"]["params"].get("linear_end") + + else: + beta_start = 0.02 + beta_end = 0.085 + scheduler_config["beta_start"] = beta_start scheduler_config["beta_end"] = beta_end scheduler_config["beta_schedule"] = "scaled_linear" scheduler_config["clip_sample"] = False scheduler_config["set_alpha_to_one"] = False - if scheduler_type == "pndm": + # to deal with an edge case StableDiffusionUpscale pipeline has two schedulers + if component_name == "low_res_scheduler": + return cls.from_config( + { + "beta_end": 0.02, + "beta_schedule": "scaled_linear", + "beta_start": 0.0001, + "clip_sample": True, + "num_train_timesteps": 1000, + "prediction_type": "epsilon", + "trained_betas": None, + "variance_type": "fixed_small", + } + ) + + if scheduler_type is None: + return cls.from_config(scheduler_config) + + elif scheduler_type == "pndm": scheduler_config["skip_prk_steps"] = True scheduler = PNDMScheduler.from_config(scheduler_config) @@ -1524,24 +1568,290 @@ def create_scheduler_from_ldm( else: raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!") - if pipeline_class_name == "StableDiffusionUpscalePipeline": - scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", subfolder="scheduler") - low_res_scheduler = DDPMScheduler.from_pretrained( - "stabilityai/stable-diffusion-x4-upscaler", subfolder="low_res_scheduler" + return scheduler + + +def _legacy_load_clip_tokenizer(cls, checkpoint, config=None, local_files_only=False): + if config: + config = {"pretrained_model_name_or_path": config} + else: + config = fetch_diffusers_config(checkpoint) + + if is_clip_model(checkpoint) or is_clip_sdxl_model(checkpoint): + clip_config = "openai/clip-vit-large-patch14" + config["pretrained_model_name_or_path"] = clip_config + subfolder = "" + + elif is_open_clip_model(checkpoint): + clip_config = "stabilityai/stable-diffusion-2" + config["pretrained_model_name_or_path"] = clip_config + subfolder = "tokenizer" + + else: + clip_config = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + config["pretrained_model_name_or_path"] = clip_config + subfolder = "" + + tokenizer = cls.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only) + + return tokenizer + + +def _legacy_load_safety_checker(local_files_only, mindspore_dtype): + # Support for loading safety checker components using the deprecated + # `load_safety_checker` argument. + + from ..pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + feature_extractor = AutoImageProcessor.from_pretrained( + "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only, mindspore_dtype=mindspore_dtype + ) + safety_checker = StableDiffusionSafetyChecker.from_pretrained( + "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only, mindspore_dtype=mindspore_dtype + ) + + return {"safety_checker": safety_checker, "feature_extractor": feature_extractor} + + +# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale; +# while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation +def swap_scale_shift(weight, dim): + shift, scale = weight.chunk(2, axis=0) + new_weight = ops.cat([scale, shift], axis=0) + return new_weight + + +def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {} + keys = list(checkpoint.keys()) + for k in keys: + if "model.diffusion_model." in k: + checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) + + num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "joint_blocks" in k))[-1] + 1 # noqa: C401 + caption_projection_dim = 1536 + + # Positional and patch embeddings. + converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("pos_embed") + converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight") + converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias") + + # Timestep embeddings. + converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop( + "t_embedder.mlp.0.weight" + ) + converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias") + converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop( + "t_embedder.mlp.2.weight" + ) + converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias") + + # Context projections. + converted_state_dict["context_embedder.weight"] = checkpoint.pop("context_embedder.weight") + converted_state_dict["context_embedder.bias"] = checkpoint.pop("context_embedder.bias") + + # Pooled context projection. + converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("y_embedder.mlp.0.weight") + converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("y_embedder.mlp.0.bias") + converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop("y_embedder.mlp.2.weight") + converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("y_embedder.mlp.2.bias") + + # Transformer blocks 🎸. + for i in range(num_layers): + # Q, K, V + sample_q, sample_k, sample_v = ops.chunk(checkpoint.pop(f"joint_blocks.{i}.x_block.attn.qkv.weight"), 3, axis=0) + context_q, context_k, context_v = ops.chunk( + checkpoint.pop(f"joint_blocks.{i}.context_block.attn.qkv.weight"), 3, axis=0 + ) + sample_q_bias, sample_k_bias, sample_v_bias = ops.chunk( + checkpoint.pop(f"joint_blocks.{i}.x_block.attn.qkv.bias"), 3, axis=0 + ) + context_q_bias, context_k_bias, context_v_bias = ops.chunk( + checkpoint.pop(f"joint_blocks.{i}.context_block.attn.qkv.bias"), 3, axis=0 ) - return { - "scheduler": scheduler, - "low_res_scheduler": low_res_scheduler, - } + converted_state_dict[f"transformer_blocks.{i}.attn.to_q.weight"] = Parameter( + ops.cat([sample_q]), name=f"transformer_blocks.{i}.attn.to_q.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn.to_q.bias"] = Parameter( + ops.cat([sample_q_bias]), name=f"transformer_blocks.{i}.attn.to_q.bias" + ) + converted_state_dict[f"transformer_blocks.{i}.attn.to_k.weight"] = Parameter( + ops.cat([sample_k]), name=f"transformer_blocks.{i}.attn.to_k.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn.to_k.bias"] = Parameter( + ops.cat([sample_k_bias]), name=f"transformer_blocks.{i}.attn.to_k.bias" + ) + converted_state_dict[f"transformer_blocks.{i}.attn.to_v.weight"] = Parameter( + ops.cat([sample_v]), name=f"transformer_blocks.{i}.attn.to_v.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn.to_v.bias"] = Parameter( + ops.cat([sample_v_bias]), name=f"transformer_blocks.{i}.attn.to_v.bias" + ) + + converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.weight"] = Parameter( + ops.cat([context_q]), name=f"transformer_blocks.{i}.attn.add_q_proj.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.bias"] = Parameter( + ops.cat([context_q_bias]), name=f"transformer_blocks.{i}.attn.add_q_proj.bias" + ) + converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.weight"] = Parameter( + ops.cat([context_k]), name=f"transformer_blocks.{i}.attn.add_k_proj.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.bias"] = Parameter( + ops.cat([context_k_bias]), name=f"transformer_blocks.{i}.attn.add_k_proj.bias" + ) + converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = Parameter( + ops.cat([context_v]), name=f"transformer_blocks.{i}.attn.add_v_proj.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = Parameter( + ops.cat([context_v_bias]), name=f"transformer_blocks.{i}.attn.add_v_proj.bias" + ) - return {"scheduler": scheduler} + # output projections. + converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.attn.proj.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.bias"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.attn.proj.bias" + ) + if not (i == num_layers - 1): + converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.weight"] = checkpoint.pop( + f"joint_blocks.{i}.context_block.attn.proj.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.bias"] = checkpoint.pop( + f"joint_blocks.{i}.context_block.attn.proj.bias" + ) + + # norms. + converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.norm1.linear.bias"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.adaLN_modulation.1.bias" + ) + if not (i == num_layers - 1): + converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = checkpoint.pop( + f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = checkpoint.pop( + f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias" + ) + else: + converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = Parameter( + swap_scale_shift( + checkpoint.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight"), + dim=caption_projection_dim, + ), + name=f"transformer_blocks.{i}.norm1_context.linear.weight", + ) + converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = Parameter( + swap_scale_shift( + checkpoint.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias"), + dim=caption_projection_dim, + ), + name=f"transformer_blocks.{i}.norm1_context.linear.bias", + ) + + # ffs. + converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.weight"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.mlp.fc1.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.bias"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.mlp.fc1.bias" + ) + converted_state_dict[f"transformer_blocks.{i}.ff.net.2.weight"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.mlp.fc2.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.ff.net.2.bias"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.mlp.fc2.bias" + ) + if not (i == num_layers - 1): + converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.weight"] = checkpoint.pop( + f"joint_blocks.{i}.context_block.mlp.fc1.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.bias"] = checkpoint.pop( + f"joint_blocks.{i}.context_block.mlp.fc1.bias" + ) + converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.weight"] = checkpoint.pop( + f"joint_blocks.{i}.context_block.mlp.fc2.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.bias"] = checkpoint.pop( + f"joint_blocks.{i}.context_block.mlp.fc2.bias" + ) + + # Final blocks. + converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight") + converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias") + converted_state_dict["norm_out.linear.weight"] = Parameter( + swap_scale_shift(checkpoint.pop("final_layer.adaLN_modulation.1.weight"), dim=caption_projection_dim), + name="norm_out.linear.weight", + ) + converted_state_dict["norm_out.linear.bias"] = Parameter( + swap_scale_shift(checkpoint.pop("final_layer.adaLN_modulation.1.bias"), dim=caption_projection_dim), + name="norm_out.linear.bias", + ) + return converted_state_dict -def _load_param_into_net(model, state_dict): + +def is_t5_in_single_file(checkpoint): + if "text_encoders.t5xxl.transformer.shared.weight" in checkpoint: + return True + + return False + + +def convert_sd3_t5_checkpoint_to_diffusers(checkpoint): + keys = list(checkpoint.keys()) + text_model_dict = {} + + remove_prefixes = ["text_encoders.t5xxl.transformer."] + + for key in keys: + for prefix in remove_prefixes: + if key.startswith(prefix): + diffusers_key = key.replace(prefix, "") + text_model_dict[diffusers_key] = checkpoint.get(key) + + return text_model_dict + + +def create_diffusers_t5_model_from_checkpoint( + cls, + checkpoint, + subfolder="", + config=None, + mindspore_dtype=None, + local_files_only=None, +): + if config: + config = {"pretrained_model_name_or_path": config} + else: + config = fetch_diffusers_config(checkpoint) + + model_config = cls.config_class.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only) + model = cls(model_config) + + diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint) + diffusers_format_checkpoint = _convert_state_dict(model, diffusers_format_checkpoint) + + _load_param_into_net(model, diffusers_format_checkpoint, mindspore_dtype) + return model + + +def _load_param_into_net(model, state_dict, mindspore_dtype=None): model_dtype = next(iter(model.get_parameters())).dtype - for _, v in state_dict.items(): - v.set_dtype(model_dtype) + state_dict_dtype = next(iter(state_dict.values())).dtype + mindspore_dtype = mindspore_dtype or model_dtype + + if model_dtype != mindspore_dtype: + for p in model.get_parameters(): + p.set_dtype(mindspore_dtype) + + if state_dict_dtype != mindspore_dtype: + for v in state_dict.values(): + v.set_dtype(mindspore_dtype) + _, ckpt_not_load = ms.load_param_into_net(model, state_dict, strict_load=True) - if len(ckpt_not_load) > 0: - logger.warning("checkpoint params not loaded: {}".format([p for p in ckpt_not_load])) + return _, ckpt_not_load diff --git a/mindone/diffusers/loaders/textual_inversion.py b/mindone/diffusers/loaders/textual_inversion.py index 981986e527..92683dd729 100644 --- a/mindone/diffusers/loaders/textual_inversion.py +++ b/mindone/diffusers/loaders/textual_inversion.py @@ -33,7 +33,7 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs): cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) - resume_download = kwargs.pop("resume_download", False) + resume_download = kwargs.pop("resume_download", None) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", None) token = kwargs.pop("token", None) @@ -308,9 +308,9 @@ def load_textual_inversion( force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to resume downloading the model weights and configuration files. If set to `False`, any - incompletely downloaded files are deleted. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1 + of Diffusers. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. @@ -443,20 +443,35 @@ def unload_textual_inversion( # Example 3: unload from SDXL pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") - embedding_path = hf_hub_download(repo_id="linoyts/web_y2k", filename="web_y2k_emb.safetensors", repo_type="model") + embedding_path = hf_hub_download( + repo_id="linoyts/web_y2k", filename="web_y2k_emb.safetensors", repo_type="model" + ) # load embeddings to the text encoders state_dict = load_file(embedding_path) # load embeddings of text_encoder 1 (CLIP ViT-L/14) - pipeline.load_textual_inversion(state_dict["clip_l"], token=["", ""], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer) + pipeline.load_textual_inversion( + state_dict["clip_l"], + token=["", ""], + text_encoder=pipeline.text_encoder, + tokenizer=pipeline.tokenizer, + ) # load embeddings of text_encoder 2 (CLIP ViT-G/14) - pipeline.load_textual_inversion(state_dict["clip_g"], token=["", ""], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2) + pipeline.load_textual_inversion( + state_dict["clip_g"], + token=["", ""], + text_encoder=pipeline.text_encoder_2, + tokenizer=pipeline.tokenizer_2, + ) # Unload explicitly from both text encoders abd tokenizers - pipeline.unload_textual_inversion(tokens=["", ""], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer) - pipeline.unload_textual_inversion(tokens=["", ""], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2) - + pipeline.unload_textual_inversion( + tokens=["", ""], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer + ) + pipeline.unload_textual_inversion( + tokens=["", ""], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2 + ) ``` """ diff --git a/mindone/diffusers/loaders/unet.py b/mindone/diffusers/loaders/unet.py index 2020c637c2..82a8cba397 100644 --- a/mindone/diffusers/loaders/unet.py +++ b/mindone/diffusers/loaders/unet.py @@ -14,43 +14,41 @@ import inspect import os from functools import partial +from pathlib import Path from typing import Callable, Dict, List, Optional, Union from huggingface_hub.utils import validate_hf_hub_args import mindspore as ms -from mindone.safetensors.mindspore import load_file +from mindone.safetensors.mindspore import load_file, save_file from ..models.embeddings import ( ImageProjection, + IPAdapterFaceIDImageProjection, + IPAdapterFaceIDPlusImageProjection, IPAdapterFullImageProjection, IPAdapterPlusImageProjection, MultiIPAdapterImageProjection, ) from ..utils import ( _get_model_file, + convert_unet_state_dict_to_peft, delete_adapter_layers, + get_adapter_name, + get_peft_kwargs, + is_peft_version, logging, set_adapter_layers, set_weights_and_activate_adapters, ) -from .single_file_utils import ( - _load_param_into_net, - convert_stable_cascade_unet_single_file_to_diffusers, - infer_stable_cascade_single_file_config, - load_single_file_model_checkpoint, -) +from .lora import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME +from .single_file_utils import _load_param_into_net +from .unet_loader_utils import _maybe_expand_lora_scales logger = logging.get_logger(__name__) -TEXT_ENCODER_NAME = "text_encoder" -UNET_NAME = "unet" - -LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" -LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" - CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin" CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors" @@ -69,7 +67,8 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be defined in [`attention_processor.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py) - and be a `torch.nn.Module` class. + and be a `torch.nn.Module` class. Currently supported: LoRA, Custom Diffusion. For LoRA, one must install + `peft`: `pip install -U peft`. Parameters: pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): @@ -88,9 +87,9 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to resume downloading the model weights and configuration files. If set to `False`, any - incompletely downloaded files are deleted. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1 + of Diffusers. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. @@ -105,10 +104,15 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - mirror (`str`, *optional*): - Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not - guarantee the timeliness or safety of the source, and you should refer to the mirror site for more - information. + network_alphas (`Dict[str, float]`): + The value of the network alpha used for stable learning and preventing underflow. This value has the + same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this + link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). + adapter_name (`str`, *optional*, defaults to None): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + weight_name (`str`, *optional*, defaults to None): + Name of the serialized state dict file. Example: @@ -126,7 +130,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict """ cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) - resume_download = kwargs.pop("resume_download", False) + resume_download = kwargs.pop("resume_download", None) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", None) token = kwargs.pop("token", None) @@ -134,9 +138,10 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) - # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. - # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning + adapter_name = kwargs.pop("adapter_name", None) + _pipeline = kwargs.pop("_pipeline", None) network_alphas = kwargs.pop("network_alphas", None) + allow_pickle = False _pipeline = kwargs.pop("_pipeline", None) # noqa: F841 @@ -200,10 +205,95 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict state_dict = pretrained_model_name_or_path_or_dict is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys()) + is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys()) + if is_custom_diffusion: raise NotImplementedError("CustomDiffusionAttnProcessor is not yet supported.") - # In fact, we have nothing to do as loading the adapter weights is already handled above - # by `set_peft_model_state_dict` on the Unet + elif is_lora: + is_model_cpu_offload, is_sequential_cpu_offload = self._process_lora( + state_dict=state_dict, + unet_identifier_key=self.unet_name, + network_alphas=network_alphas, + adapter_name=adapter_name, + _pipeline=_pipeline, + ) + else: + raise ValueError( + f"{model_file} does not seem to be in the correct format expected by Custom Diffusion training." + ) + + def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline): + # This method does the following things: + # 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy + # format. For legacy format no filtering is applied. + # 2. Converts the `state_dict` to the `peft` compatible format. + # 3. Creates a `LoraConfig` and then injects the converted `state_dict` into the UNet per the + # `LoraConfig` specs. + # 4. It also reports if the underlying `_pipeline` has any kind of offloading inside of it. + from mindone.diffusers._peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict + + keys = list(state_dict.keys()) + + unet_keys = [k for k in keys if k.startswith(unet_identifier_key)] + unet_state_dict = {k.replace(f"{unet_identifier_key}.", ""): v for k, v in state_dict.items() if k in unet_keys} + + if network_alphas is not None: + alpha_keys = [k for k in network_alphas.keys() if k.startswith(unet_identifier_key)] + network_alphas = { + k.replace(f"{unet_identifier_key}.", ""): v for k, v in network_alphas.items() if k in alpha_keys + } + + is_model_cpu_offload = False + is_sequential_cpu_offload = False + state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict + + if len(state_dict_to_be_used) > 0: + if adapter_name in getattr(self, "peft_config", {}): + raise ValueError( + f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name." + ) + + state_dict = convert_unet_state_dict_to_peft(state_dict_to_be_used) + + if network_alphas is not None: + # The alphas state dict have the same structure as Unet, thus we convert it to peft format using + # `convert_unet_state_dict_to_peft` method. + network_alphas = convert_unet_state_dict_to_peft(network_alphas) + + rank = {} + for key, val in state_dict.items(): + if "lora_B" in key: + rank[key] = val.shape[1] + + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True) + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"]: + if is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<", "0.9.0"): + lora_config_kwargs.pop("use_dora") + lora_config = LoraConfig(**lora_config_kwargs) + + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(self) + + inject_adapter_in_model(lora_config, self, adapter_name=adapter_name) + incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name) + + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + return is_model_cpu_offload, is_sequential_cpu_offload def save_attn_procs( self, @@ -214,7 +304,74 @@ def save_attn_procs( safe_serialization: bool = True, **kwargs, ): - raise NotImplementedError(f"{self.__class__.__name__}.save_attn_procs is not yet supported.") + r""" + Save attention processor layers to a directory so that it can be reloaded with the + [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save an attention processor to (will be created if it doesn't exist). + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or with `pickle`. + + Example: + + ```py + import mindspore + from mindone.diffusers import DiffusionPipeline + + pipeline = DiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + mindspore_dtype=mindspore.float16, + ) + pipeline.unet.load_attn_procs("path-to-save-model", weight_name="lora_diffusion_weights.safetensors") + pipeline.unet.save_attn_procs("path-to-save-model", weight_name="lora_diffusion_weights.safetensors") + ``` + """ + from ..models.attention_processor import CustomDiffusionAttnProcessor + + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + is_custom_diffusion = any( + isinstance(x, CustomDiffusionAttnProcessor) for (_, x) in self.attn_processors.items() + ) + if is_custom_diffusion: + raise NotImplementedError( + f"is_custom_diffusion is not yet supported in {self.__class__.__name__}.save_attn_procs ." + ) + else: + from mindone.diffusers._peft.utils import get_peft_model_state_dict + + state_dict = get_peft_model_state_dict(self) + + if save_function is None: + if safe_serialization: + save_function = partial(save_file, metadata={"format": "np"}) + else: + save_function = ms.save_checkpoint + + os.makedirs(save_directory, exist_ok=True) + + if weight_name is None: + if safe_serialization: + weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE + else: + weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME if is_custom_diffusion else LORA_WEIGHT_NAME + + # Save the model + save_path = Path(save_directory, weight_name).as_posix() + save_function(state_dict, save_path) + logger.info(f"Model weights saved in {save_path}") def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None): self.lora_scale = lora_scale @@ -249,10 +406,17 @@ def _unfuse_lora_apply(self, module): if isinstance(module, BaseTunerLayer): module.unmerge() + def unload_lora(self): + from ..utils import recurse_remove_peft_layers + + recurse_remove_peft_layers(self) + if hasattr(self, "peft_config"): + del self.peft_config + def set_adapters( self, adapter_names: Union[List[str], str], - weights: Optional[Union[List[float], float]] = None, + weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None, ): """ Set the currently active adapters for use in the UNet. @@ -282,9 +446,9 @@ def set_adapters( """ adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names - if weights is None: - weights = [1.0] * len(adapter_names) - elif isinstance(weights, float): + # Expand weights into a list, one entry per adapter + # examples for e.g. 2 adapters: [{...}, 7] -> [7,7] ; None -> [None, None] + if not isinstance(weights, list): weights = [weights] * len(adapter_names) if len(adapter_names) != len(weights): @@ -292,6 +456,13 @@ def set_adapters( f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}." ) + # Set None values to default of 1.0 + # e.g. [{...}, 7] -> [{...}, 7] ; [None, None] -> [1.0, 1.0] + weights = [w if w is not None else 1.0 for w in weights] + + # e.g. [{...}, 7] -> [{expanded dict...}, 7] + weights = _maybe_expand_lora_scales(self, weights) + set_weights_and_activate_adapters(self, adapter_names, weights) def disable_lora(self): @@ -403,6 +574,91 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict): diffusers_name = diffusers_name.replace("proj.2", "ff.net.2") diffusers_name = diffusers_name.replace("proj.3", "norm") updated_state_dict[diffusers_name] = value + elif "perceiver_resampler.proj_in.weight" in state_dict: + # IP-Adapter Face ID Plus + id_embeddings_dim = state_dict["proj.0.weight"].shape[1] + embed_dims = state_dict["perceiver_resampler.proj_in.weight"].shape[0] + hidden_dims = state_dict["perceiver_resampler.proj_in.weight"].shape[1] + output_dims = state_dict["perceiver_resampler.proj_out.weight"].shape[0] + heads = state_dict["perceiver_resampler.layers.0.0.to_q.weight"].shape[0] // 64 + + image_projection = IPAdapterFaceIDPlusImageProjection( + embed_dims=embed_dims, + output_dims=output_dims, + hidden_dims=hidden_dims, + heads=heads, + id_embeddings_dim=id_embeddings_dim, + ) + + for key, value in state_dict.items(): + diffusers_name = key.replace("perceiver_resampler.", "") + diffusers_name = diffusers_name.replace("0.to", "attn.to") + diffusers_name = diffusers_name.replace("0.1.0.", "0.ff.0.") + diffusers_name = diffusers_name.replace("0.1.1.weight", "0.ff.1.net.0.proj.weight") + diffusers_name = diffusers_name.replace("0.1.3.weight", "0.ff.1.net.2.weight") + diffusers_name = diffusers_name.replace("1.1.0.", "1.ff.0.") + diffusers_name = diffusers_name.replace("1.1.1.weight", "1.ff.1.net.0.proj.weight") + diffusers_name = diffusers_name.replace("1.1.3.weight", "1.ff.1.net.2.weight") + diffusers_name = diffusers_name.replace("2.1.0.", "2.ff.0.") + diffusers_name = diffusers_name.replace("2.1.1.weight", "2.ff.1.net.0.proj.weight") + diffusers_name = diffusers_name.replace("2.1.3.weight", "2.ff.1.net.2.weight") + diffusers_name = diffusers_name.replace("3.1.0.", "3.ff.0.") + diffusers_name = diffusers_name.replace("3.1.1.weight", "3.ff.1.net.0.proj.weight") + diffusers_name = diffusers_name.replace("3.1.3.weight", "3.ff.1.net.2.weight") + diffusers_name = diffusers_name.replace("layers.0.0", "layers.0.ln0") + diffusers_name = diffusers_name.replace("layers.0.1", "layers.0.ln1") + diffusers_name = diffusers_name.replace("layers.1.0", "layers.1.ln0") + diffusers_name = diffusers_name.replace("layers.1.1", "layers.1.ln1") + diffusers_name = diffusers_name.replace("layers.2.0", "layers.2.ln0") + diffusers_name = diffusers_name.replace("layers.2.1", "layers.2.ln1") + diffusers_name = diffusers_name.replace("layers.3.0", "layers.3.ln0") + diffusers_name = diffusers_name.replace("layers.3.1", "layers.3.ln1") + + if "norm1" in diffusers_name: + updated_state_dict[diffusers_name.replace("0.norm1", "0")] = value + elif "norm2" in diffusers_name: + updated_state_dict[diffusers_name.replace("0.norm2", "1")] = value + elif "to_kv" in diffusers_name: + v_chunk = value.chunk(2, axis=0) + updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = ms.Parameter( + v_chunk[0], name=diffusers_name.replace("to_kv", "to_k") + ) + updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = ms.Parameter( + v_chunk[1], name=diffusers_name.replace("to_kv", "to_v") + ) + elif "to_out" in diffusers_name: + updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value + elif "proj.0.weight" == diffusers_name: + updated_state_dict["proj.net.0.proj.weight"] = value + elif "proj.0.bias" == diffusers_name: + updated_state_dict["proj.net.0.proj.bias"] = value + elif "proj.2.weight" == diffusers_name: + updated_state_dict["proj.net.2.weight"] = value + elif "proj.2.bias" == diffusers_name: + updated_state_dict["proj.net.2.bias"] = value + else: + updated_state_dict[diffusers_name] = value + + elif "norm.weight" in state_dict: + # IP-Adapter Face ID + id_embeddings_dim_in = state_dict["proj.0.weight"].shape[1] + id_embeddings_dim_out = state_dict["proj.0.weight"].shape[0] + multiplier = id_embeddings_dim_out // id_embeddings_dim_in + norm_layer = "norm.weight" + cross_attention_dim = state_dict[norm_layer].shape[0] + num_tokens = state_dict["proj.2.weight"].shape[0] // cross_attention_dim + + image_projection = IPAdapterFaceIDImageProjection( + cross_attention_dim=cross_attention_dim, + image_embed_dim=id_embeddings_dim_in, + mult=multiplier, + num_tokens=num_tokens, + ) + + for key, value in state_dict.items(): + diffusers_name = key.replace("proj.0", "ff.net.0.proj") + diffusers_name = diffusers_name.replace("proj.2", "ff.net.2") + updated_state_dict[diffusers_name] = value else: # IP-Adapter Plus @@ -410,7 +666,12 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict): embed_dims = state_dict["proj_in.weight"].shape[1] output_dims = state_dict["proj_out.weight"].shape[0] hidden_dims = state_dict["latents"].shape[2] - heads = state_dict["layers.0.0.to_q.weight"].shape[0] // 64 + attn_key_present = any("attn" in k for k in state_dict) + heads = ( + state_dict["layers.0.attn.to_q.weight"].shape[0] // 64 + if attn_key_present + else state_dict["layers.0.0.to_q.weight"].shape[0] // 64 + ) image_projection = IPAdapterPlusImageProjection( embed_dims=embed_dims, @@ -422,22 +683,54 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict): for key, value in state_dict.items(): diffusers_name = key.replace("0.to", "2.to") - diffusers_name = diffusers_name.replace("1.0.weight", "3.0.weight") - diffusers_name = diffusers_name.replace("1.0.bias", "3.0.bias") - diffusers_name = diffusers_name.replace("1.1.weight", "3.1.net.0.proj.weight") - diffusers_name = diffusers_name.replace("1.3.weight", "3.1.net.2.weight") - if "norm1" in diffusers_name: - updated_state_dict[diffusers_name.replace("0.norm1", "0")] = value - elif "norm2" in diffusers_name: - updated_state_dict[diffusers_name.replace("0.norm2", "1")] = value - elif "to_kv" in diffusers_name: - v_chunk = value.chunk(2, dim=0) - updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0] - updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1] + diffusers_name = diffusers_name.replace("0.0.norm1", "0.ln0") + diffusers_name = diffusers_name.replace("0.0.norm2", "0.ln1") + diffusers_name = diffusers_name.replace("1.0.norm1", "1.ln0") + diffusers_name = diffusers_name.replace("1.0.norm2", "1.ln1") + diffusers_name = diffusers_name.replace("2.0.norm1", "2.ln0") + diffusers_name = diffusers_name.replace("2.0.norm2", "2.ln1") + diffusers_name = diffusers_name.replace("3.0.norm1", "3.ln0") + diffusers_name = diffusers_name.replace("3.0.norm2", "3.ln1") + + if "to_kv" in diffusers_name: + parts = diffusers_name.split(".") + parts[2] = "attn" + diffusers_name = ".".join(parts) + v_chunk = value.chunk(2, axis=0) + updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = ms.Parameter( + v_chunk[0], name=diffusers_name.replace("to_kv", "to_k") + ) + updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = ms.Parameter( + v_chunk[1], name=diffusers_name.replace("to_kv", "to_v") + ) + elif "to_q" in diffusers_name: + parts = diffusers_name.split(".") + parts[2] = "attn" + diffusers_name = ".".join(parts) + updated_state_dict[diffusers_name] = value elif "to_out" in diffusers_name: + parts = diffusers_name.split(".") + parts[2] = "attn" + diffusers_name = ".".join(parts) updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value else: + diffusers_name = diffusers_name.replace("0.1.0", "0.ff.0") + diffusers_name = diffusers_name.replace("0.1.1", "0.ff.1.net.0.proj") + diffusers_name = diffusers_name.replace("0.1.3", "0.ff.1.net.2") + + diffusers_name = diffusers_name.replace("1.1.0", "1.ff.0") + diffusers_name = diffusers_name.replace("1.1.1", "1.ff.1.net.0.proj") + diffusers_name = diffusers_name.replace("1.1.3", "1.ff.1.net.2") + + diffusers_name = diffusers_name.replace("2.1.0", "2.ff.0") + diffusers_name = diffusers_name.replace("2.1.1", "2.ff.1.net.0.proj") + diffusers_name = diffusers_name.replace("2.1.3", "2.ff.1.net.2") + + diffusers_name = diffusers_name.replace("3.1.0", "3.ff.0") + diffusers_name = diffusers_name.replace("3.1.1", "3.ff.1.net.0.proj") + diffusers_name = diffusers_name.replace("3.1.3", "3.ff.1.net.2") + updated_state_dict[diffusers_name] = value updated_state_dict[diffusers_name] = value _load_param_into_net(image_projection, updated_state_dict) @@ -464,6 +757,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts): if cross_attention_dim is None or "motion_modules" in name: attn_processor_class = AttnProcessor attn_procs[name] = attn_processor_class() + else: attn_processor_class = IPAdapterAttnProcessor num_image_text_embeds = [] @@ -474,6 +768,12 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts): elif "proj.3.weight" in state_dict["image_proj"]: # IP-Adapter Full Face num_image_text_embeds += [257] # 256 CLIP tokens + 1 CLS token + elif "perceiver_resampler.proj_in.weight" in state_dict["image_proj"]: + # IP-Adapter Face ID Plus + num_image_text_embeds += [4] + elif "norm.weight" in state_dict["image_proj"]: + # IP-Adapter Face ID + num_image_text_embeds += [4] else: # IP-Adapter Plus num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]] @@ -513,96 +813,59 @@ def _load_ip_adapter_weights(self, state_dicts): self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers) self.config.encoder_hid_dim_type = "ip_image_proj" - self.encoder_hid_dim_type = "ip_image_proj" + self.config["encoder_hid_dim_type"] = "ip_image_proj" # not same with `self.config.encoder_hid_dim_type` + self.encoder_hid_dim_type = "ip_image_proj" # used in UNet2DConditionModel.construct() self.to(dtype=self.dtype) - -class FromOriginalUNetMixin: - """ - Load pretrained UNet model weights saved in the `.ckpt` or `.safetensors` format into a [`StableCascadeUNet`]. - """ - - @classmethod - @validate_hf_hub_args - def from_single_file(cls, pretrained_model_link_or_path, **kwargs): - r""" - Instantiate a [`StableCascadeUNet`] from pretrained StableCascadeUNet weights saved in the original `.ckpt` or - `.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default. - - Parameters: - pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*): - Can be either: - - A link to the `.ckpt` file (for example - `"https://huggingface.co//blob/main/.ckpt"`) on the Hub. - - A path to a *file* containing all pipeline weights. - config: (`dict`, *optional*): - Dictionary containing the configuration of the model: - mindspore_dtype (`str` or `mindspore.dtype`, *optional*): - Override the default `mindspore.dtype` and load the model with another dtype. If `"auto"` is passed, the - dtype is automatically derived from the model's weights. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached if the standard cache - is not used. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to resume downloading the model weights and configuration files. If set to `False`, any - incompletely downloaded files are deleted. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether to only load local model weights and configuration files or not. If set to True, the model - won't be downloaded from the Hub. - token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from - `diffusers-cli login` (stored in `~/.huggingface`) is used. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier - allowed by Git. - kwargs (remaining dictionary of keyword arguments, *optional*): - Can be used to overwrite load and saveable variables of the model. - - """ - class_name = cls.__name__ - if class_name != "StableCascadeUNet": - raise ValueError("FromOriginalUNetMixin is currently only compatible with StableCascadeUNet") - - config = kwargs.pop("config", None) - resume_download = kwargs.pop("resume_download", False) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - token = kwargs.pop("token", None) - cache_dir = kwargs.pop("cache_dir", None) - local_files_only = kwargs.pop("local_files_only", None) - revision = kwargs.pop("revision", None) - mindspore_dtype = kwargs.pop("mindspore_dtype", None) - - checkpoint = load_single_file_model_checkpoint( - pretrained_model_link_or_path, - resume_download=resume_download, - force_download=force_download, - proxies=proxies, - token=token, - cache_dir=cache_dir, - local_files_only=local_files_only, - revision=revision, - ) - - if config is None: - config = infer_stable_cascade_single_file_config(checkpoint) - model_config = cls.load_config(**config, **kwargs) - else: - model_config = config - - model = cls.from_config(model_config, **kwargs) - - diffusers_format_checkpoint = convert_stable_cascade_unet_single_file_to_diffusers(checkpoint) - - _load_param_into_net(model, diffusers_format_checkpoint) - - if mindspore_dtype is not None: - model.to(mindspore_dtype) - - return model + def _load_ip_adapter_loras(self, state_dicts): + lora_dicts = {} + for key_id, name in enumerate(self.attn_processors.keys()): + for i, state_dict in enumerate(state_dicts): + if f"{key_id}.to_k_lora.down.weight" in state_dict["ip_adapter"]: + if i not in lora_dicts: + lora_dicts[i] = {} + lora_dicts[i].update( + { + f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][ + f"{key_id}.to_k_lora.down.weight" + ] + } + ) + lora_dicts[i].update( + { + f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][ + f"{key_id}.to_q_lora.down.weight" + ] + } + ) + lora_dicts[i].update( + { + f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][ + f"{key_id}.to_v_lora.down.weight" + ] + } + ) + lora_dicts[i].update( + { + f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][ + f"{key_id}.to_out_lora.down.weight" + ] + } + ) + lora_dicts[i].update( + {f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.up.weight"]} + ) + lora_dicts[i].update( + {f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.up.weight"]} + ) + lora_dicts[i].update( + {f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.up.weight"]} + ) + lora_dicts[i].update( + { + f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][ + f"{key_id}.to_out_lora.up.weight" + ] + } + ) + return lora_dicts diff --git a/mindone/diffusers/loaders/unet_loader_utils.py b/mindone/diffusers/loaders/unet_loader_utils.py new file mode 100644 index 0000000000..c8501ed3c4 --- /dev/null +++ b/mindone/diffusers/loaders/unet_loader_utils.py @@ -0,0 +1,160 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +from typing import TYPE_CHECKING, Dict, List, Union + +from ..utils import logging + +if TYPE_CHECKING: + # import here to avoid circular imports + from ..models import UNet2DConditionModel + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def _translate_into_actual_layer_name(name): + """Translate user-friendly name (e.g. 'mid') into actual layer name (e.g. 'mid_block.attentions.0')""" + if name == "mid": + return "mid_block.attentions.0" + + updown, block, attn = name.split(".") + + updown = updown.replace("down", "down_blocks").replace("up", "up_blocks") + block = block.replace("block_", "") + attn = "attentions." + attn + + return ".".join((updown, block, attn)) + + +def _maybe_expand_lora_scales(unet: "UNet2DConditionModel", weight_scales: List[Union[float, Dict]], default_scale=1.0): + blocks_with_transformer = { + "down": [i for i, block in enumerate(unet.down_blocks) if hasattr(block, "attentions")], + "up": [i for i, block in enumerate(unet.up_blocks) if hasattr(block, "attentions")], + } + transformer_per_block = {"down": unet.config.layers_per_block, "up": unet.config.layers_per_block + 1} + + expanded_weight_scales = [ + _maybe_expand_lora_scales_for_one_adapter( + weight_for_adapter, + blocks_with_transformer, + transformer_per_block, + unet.parameters_dict(), + default_scale=default_scale, + ) + for weight_for_adapter in weight_scales + ] + + return expanded_weight_scales + + +def _maybe_expand_lora_scales_for_one_adapter( + scales: Union[float, Dict], + blocks_with_transformer: Dict[str, int], + transformer_per_block: Dict[str, int], + state_dict: None, + default_scale: float = 1.0, +): + """ + Expands the inputs into a more granular dictionary. See the example below for more details. + + Parameters: + scales (`Union[float, Dict]`): + Scales dict to expand. + blocks_with_transformer (`Dict[str, int]`): + Dict with keys 'up' and 'down', showing which blocks have transformer layers + transformer_per_block (`Dict[str, int]`): + Dict with keys 'up' and 'down', showing how many transformer layers each block has + + E.g. turns + ```python + scales = {"down": 2, "mid": 3, "up": {"block_0": 4, "block_1": [5, 6, 7]}} + blocks_with_transformer = {"down": [1, 2], "up": [0, 1]} + transformer_per_block = {"down": 2, "up": 3} + ``` + into + ```python + { + "down.block_1.0": 2, + "down.block_1.1": 2, + "down.block_2.0": 2, + "down.block_2.1": 2, + "mid": 3, + "up.block_0.0": 4, + "up.block_0.1": 4, + "up.block_0.2": 4, + "up.block_1.0": 5, + "up.block_1.1": 6, + "up.block_1.2": 7, + } + ``` + """ + if sorted(blocks_with_transformer.keys()) != ["down", "up"]: + raise ValueError("blocks_with_transformer needs to be a dict with keys `'down' and `'up'`") + + if sorted(transformer_per_block.keys()) != ["down", "up"]: + raise ValueError("transformer_per_block needs to be a dict with keys `'down' and `'up'`") + + if not isinstance(scales, dict): + # don't expand if scales is a single number + return scales + + scales = copy.deepcopy(scales) + + if "mid" not in scales: + scales["mid"] = default_scale + elif isinstance(scales["mid"], list): + if len(scales["mid"]) == 1: + scales["mid"] = scales["mid"][0] + else: + raise ValueError(f"Expected 1 scales for mid, got {len(scales['mid'])}.") + + for updown in ["up", "down"]: + if updown not in scales: + scales[updown] = default_scale + + # eg {"down": 1} to {"down": {"block_1": 1, "block_2": 1}}} + if not isinstance(scales[updown], dict): + scales[updown] = {f"block_{i}": copy.deepcopy(scales[updown]) for i in blocks_with_transformer[updown]} + + # eg {"down": {"block_1": 1}} to {"down": {"block_1": [1, 1]}} + for i in blocks_with_transformer[updown]: + block = f"block_{i}" + # set not assigned blocks to default scale + if block not in scales[updown]: + scales[updown][block] = default_scale + if not isinstance(scales[updown][block], list): + scales[updown][block] = [scales[updown][block] for _ in range(transformer_per_block[updown])] + elif len(scales[updown][block]) == 1: + # a list specifying scale to each masked IP input + scales[updown][block] = scales[updown][block] * transformer_per_block[updown] + elif len(scales[updown][block]) != transformer_per_block[updown]: + raise ValueError( + f"Expected {transformer_per_block[updown]} scales for {updown}.{block}, got {len(scales[updown][block])}." + ) + + # eg {"down": "block_1": [1, 1]}} to {"down.block_1.0": 1, "down.block_1.1": 1} + for i in blocks_with_transformer[updown]: + block = f"block_{i}" + for tf_idx, value in enumerate(scales[updown][block]): + scales[f"{updown}.{block}.{tf_idx}"] = value + + del scales[updown] + + for layer in scales.keys(): + if not any(_translate_into_actual_layer_name(layer) in module for module in state_dict.keys()): + raise ValueError( + f"Can't set lora scale for layer {layer}. It either doesn't exist in this unet or it has no attentions." + ) + + return {_translate_into_actual_layer_name(name): weight for name, weight in scales.items()} diff --git a/mindone/diffusers/models/__init__.py b/mindone/diffusers/models/__init__.py index 84f00e2df4..e030b857d3 100644 --- a/mindone/diffusers/models/__init__.py +++ b/mindone/diffusers/models/__init__.py @@ -23,15 +23,22 @@ "autoencoders.autoencoder_kl_temporal_decoder": ["AutoencoderKLTemporalDecoder"], "autoencoders.autoencoder_tiny": ["AutoencoderTiny"], "autoencoders.consistency_decoder_vae": ["ConsistencyDecoderVAE"], + "autoencoders.vq_model": ["VQModel"], "controlnet": ["ControlNetModel"], + "controlnet_sd3": ["SD3ControlNetModel", "SD3MultiControlNetModel"], + "controlnet_xs": ["ControlNetXSAdapter", "UNetControlNetXSModel"], "dual_transformer_2d": ["DualTransformer2DModel"], "embeddings": ["ImageProjection"], "modeling_utils": ["ModelMixin"], + "transformers.dit_transformer_2d": ["DiTTransformer2DModel"], + "transformers.dual_transformer_2d": ["DualTransformer2DModel"], + "transformers.hunyuan_transformer_2d": ["HunyuanDiT2DModel"], + "transformers.pixart_transformer_2d": ["PixArtTransformer2DModel"], "transformers.prior_transformer": ["PriorTransformer"], "transformers.t5_film_transformer": ["T5FilmDecoder"], "transformers.transformer_2d": ["Transformer2DModel"], - "transformers.transformer_temporal": ["TransformerTemporalModel"], "transformers.transformer_sd3": ["SD3Transformer2DModel"], + "transformers.transformer_temporal": ["TransformerTemporalModel"], "unets.unet_1d": ["UNet1DModel"], "unets.unet_2d": ["UNet2DModel"], "unets.unet_2d_condition": ["UNet2DConditionModel"], @@ -42,7 +49,6 @@ "unets.unet_stable_cascade": ["StableCascadeUNet"], "unets.unet_spatio_temporal_condition": ["UNetSpatioTemporalConditionModel"], "unets.uvit_2d": ["UVit2DModel"], - "vq_model": ["VQModel"], } if TYPE_CHECKING: @@ -53,12 +59,18 @@ AutoencoderKLTemporalDecoder, AutoencoderTiny, ConsistencyDecoderVAE, + VQModel, ) from .controlnet import ControlNetModel + from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel + from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel from .embeddings import ImageProjection from .modeling_utils import ModelMixin from .transformers import ( + DiTTransformer2DModel, DualTransformer2DModel, + HunyuanDiT2DModel, + PixArtTransformer2DModel, PriorTransformer, SD3Transformer2DModel, T5FilmDecoder, @@ -78,7 +90,6 @@ UNetSpatioTemporalConditionModel, UVit2DModel, ) - from .vq_model import VQModel else: import sys diff --git a/mindone/diffusers/models/adapter.py b/mindone/diffusers/models/adapter.py index c2583ca4a5..e55cdc551e 100644 --- a/mindone/diffusers/models/adapter.py +++ b/mindone/diffusers/models/adapter.py @@ -90,6 +90,7 @@ def construct(self, xs: ms.Tensor, adapter_weights: Optional[List[float]] = None accume_state = None for x, w, adapter in zip(xs, adapter_weights, self.adapters): features = adapter(x) + w = w.to(x.dtype) # cast manually as torch do the same automatically for scaler tensors if accume_state is None: accume_state = features for i in range(len(accume_state)): diff --git a/mindone/diffusers/models/attention.py b/mindone/diffusers/models/attention.py index d5b9d7fc9b..0f164e351c 100644 --- a/mindone/diffusers/models/attention.py +++ b/mindone/diffusers/models/attention.py @@ -765,7 +765,6 @@ def __init__( if inner_dim is None: inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim - linear_cls = nn.Dense if activation_fn == "gelu": act_fn = GELU(dim, inner_dim, bias=bias) @@ -782,7 +781,7 @@ def __init__( # project dropout net.append(nn.Dropout(p=dropout)) # project out - net.append(linear_cls(inner_dim, dim_out, has_bias=bias)) + net.append(nn.Dense(inner_dim, dim_out, has_bias=bias)) # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout if final_dropout: net.append(nn.Dropout(p=dropout)) diff --git a/mindone/diffusers/models/attention_processor.py b/mindone/diffusers/models/attention_processor.py index 25828ef3fe..c08326fed5 100644 --- a/mindone/diffusers/models/attention_processor.py +++ b/mindone/diffusers/models/attention_processor.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Optional, Union +import math +from typing import Callable, List, Optional, Union import mindspore as ms from mindspore import nn, ops @@ -85,6 +86,7 @@ def __init__( upcast_softmax: bool = False, cross_attention_norm: Optional[str] = None, cross_attention_norm_num_groups: int = 32, + qk_norm: Optional[str] = None, added_kv_proj_dim: Optional[int] = None, norm_num_groups: Optional[int] = None, spatial_norm_dim: Optional[int] = None, @@ -122,6 +124,8 @@ def __init__( self.scale_qk = scale_qk self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + # use `scale_sqrt` for ops.baddbmm to get same outputs with torch.baddbmm in fp16 + self.scale_sqrt = float(math.sqrt(self.scale)) self.heads = out_dim // dim_head if out_dim is not None else heads # for slice_size > 0 the attention score computation @@ -147,6 +151,15 @@ def __init__( else: self.spatial_norm = None + if qk_norm is None: + self.norm_q = None + self.norm_k = None + elif qk_norm == "layer_norm": + self.norm_q = LayerNorm(dim_head, eps=eps) + self.norm_k = LayerNorm(dim_head, eps=eps) + else: + raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'") + self.cross_attention_norm = cross_attention_norm if cross_attention_norm is None: self.norm_cross = None @@ -171,26 +184,23 @@ def __init__( f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" ) - linear_cls = nn.Dense - - self.linear_cls = linear_cls - self.to_q = linear_cls(query_dim, self.inner_dim, has_bias=bias) + self.to_q = nn.Dense(query_dim, self.inner_dim, has_bias=bias) if not self.only_cross_attention: # only relevant for the `AddedKVProcessor` classes - self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, has_bias=bias) - self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, has_bias=bias) + self.to_k = nn.Dense(self.cross_attention_dim, self.inner_dim, has_bias=bias) + self.to_v = nn.Dense(self.cross_attention_dim, self.inner_dim, has_bias=bias) else: self.to_k = None self.to_v = None if self.added_kv_proj_dim is not None: - self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim) - self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim) + self.add_k_proj = nn.Dense(added_kv_proj_dim, self.inner_dim) + self.add_v_proj = nn.Dense(added_kv_proj_dim, self.inner_dim) if self.context_pre_only is not None: self.add_q_proj = nn.Dense(added_kv_proj_dim, self.inner_dim) - self.to_out = nn.CellList([linear_cls(self.inner_dim, self.out_dim, has_bias=out_bias), nn.Dropout(p=dropout)]) + self.to_out = nn.CellList([nn.Dense(self.inner_dim, self.out_dim, has_bias=out_bias), nn.Dropout(p=dropout)]) if self.context_pre_only is not None and not self.context_pre_only: self.to_add_out = nn.Dense(self.inner_dim, self.out_dim, has_bias=out_bias) @@ -424,17 +434,17 @@ def get_attention_scores(self, query: ms.Tensor, key: ms.Tensor, attention_mask: key = key.float() if attention_mask is None: - attention_scores = self.scale * ops.bmm( - query, - key.swapaxes(-1, -2), + attention_scores = ops.bmm( + query * self.scale_sqrt, + key.swapaxes(-1, -2) * self.scale_sqrt, ) else: attention_scores = ops.baddbmm( attention_mask.to(query.dtype), - query, - key.swapaxes(-1, -2), + query * self.scale_sqrt, + key.swapaxes(-1, -2) * self.scale_sqrt, beta=1, - alpha=self.scale, + alpha=1, ) if self.upcast_softmax: @@ -527,7 +537,7 @@ def fuse_projections(self, fuse=True): out_features = concatenated_weights.shape[0] # create a new single projection layer and copy over the weights. - self.to_qkv = self.linear_cls(in_features, out_features, has_bias=self.use_bias, dtype=dtype) + self.to_qkv = nn.Dense(in_features, out_features, has_bias=self.use_bias, dtype=dtype) self.to_qkv.weight.set_data(concatenated_weights) if self.use_bias: concatenated_bias = ops.cat([self.to_q.bias, self.to_k.bias, self.to_v.bias]) @@ -538,7 +548,7 @@ def fuse_projections(self, fuse=True): in_features = concatenated_weights.shape[1] out_features = concatenated_weights.shape[0] - self.to_kv = self.linear_cls(in_features, out_features, has_bias=self.use_bias, dtype=dtype) + self.to_kv = nn.Dense(in_features, out_features, has_bias=self.use_bias, dtype=dtype) self.to_kv.weight.set_data(concatenated_weights) if self.use_bias: concatenated_bias = ops.cat([self.to_k.bias, self.to_v.bias]) @@ -1044,6 +1054,104 @@ def __call__( return hidden_states +@ms.jit_class +class HunyuanAttnProcessor: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the HunyuanDiT model. It applies a s normalization layer and rotary embedding on query and key vector. + """ + + def __init__(self) -> None: + # move importing from __call__ to __init__ as it is not supported in construct() + from .embeddings import apply_rotary_emb + + self.apply_rotary_emb = apply_rotary_emb + + def __call__( + self, + attn: Attention, + hidden_states: ms.Tensor, + encoder_hidden_states: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + temb: Optional[ms.Tensor] = None, + image_rotary_emb: Optional[ms.Tensor] = None, + ) -> ms.Tensor: + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + batch_size, channel, height, width = (None,) * 4 + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).swapaxes(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.swapaxes(1, 2)).swapaxes(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = attn.head_to_batch_dim(query, out_dim=4) + key = attn.head_to_batch_dim(key, out_dim=4) + value = attn.head_to_batch_dim(value, out_dim=4) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query = self.apply_rotary_emb(query, image_rotary_emb) + if not attn.is_cross_attention: + key = self.apply_rotary_emb(key, image_rotary_emb) + + # # the output of sdp = (batch, num_heads, seq_len, head_dim) + # # TODO: add support for attn.scale when we move to Torch 2.1 + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = ops.bmm(attention_probs, value) + + hidden_states = hidden_states.swapaxes(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.swapaxes(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + class SpatialNorm(nn.Cell): """ Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. @@ -1077,7 +1185,7 @@ def construct(self, f: ms.Tensor, zq: ms.Tensor) -> ms.Tensor: class IPAdapterAttnProcessor(nn.Cell): r""" - Attention processor for Multiple IP-Adapater. + Attention processor for Multiple IP-Adapters. Args: hidden_size (`int`): @@ -1172,15 +1280,33 @@ def __call__( hidden_states = attn.batch_to_head_dim(hidden_states) if ip_adapter_masks is not None: - if not isinstance(ip_adapter_masks, ms.Tensor) or ip_adapter_masks.ndim != 4: + if not isinstance(ip_adapter_masks, List): + # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width] + ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1)) + if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)): raise ValueError( - " ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]." - " Please use `IPAdapterMaskProcessor` to preprocess your mask" - ) - if len(ip_adapter_masks) != len(self.scale): - raise ValueError( - f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})" + f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match " + f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states " + f"({len(ip_hidden_states)})" ) + else: + for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)): + if not isinstance(mask, ms.Tensor) or mask.ndim != 4: + raise ValueError( + "Each element of the ip_adapter_masks array should be a tensor with shape " + "[1, num_images_for_ip_adapter, height, width]." + " Please use `IPAdapterMaskProcessor` to preprocess your mask" + ) + if mask.shape[1] != ip_state.shape[1]: + raise ValueError( + f"Number of masks ({mask.shape[1]}) does not match " + f"number of ip images ({ip_state.shape[1]}) at index {index}" + ) + if isinstance(scale, list) and not len(scale) == mask.shape[1]: + raise ValueError( + f"Number of masks ({mask.shape[1]}) does not match " + f"number of scales ({len(scale)}) at index {index}" + ) else: ip_adapter_masks = [None] * len(self.scale) @@ -1188,26 +1314,51 @@ def __call__( for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks ): - ip_key = to_k_ip(current_ip_hidden_states) - ip_value = to_v_ip(current_ip_hidden_states) - - ip_key = attn.head_to_batch_dim(ip_key) - ip_value = attn.head_to_batch_dim(ip_value) - - ip_attention_probs = attn.get_attention_scores(query, ip_key, None) - current_ip_hidden_states = ops.bmm(ip_attention_probs, ip_value) - current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states) - - if mask is not None: - mask_downsample = IPAdapterMaskProcessor.downsample( - mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2] - ) + skip = False + if isinstance(scale, list): + if all(s == 0 for s in scale): + skip = True + elif scale == 0: + skip = True + if not skip: + if mask is not None: + if not isinstance(scale, list): + scale = [scale] * mask.shape[1] + + current_num_images = mask.shape[1] + for i in range(current_num_images): + ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :]) + ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :]) + + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) + + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + _current_ip_hidden_states = ops.bmm(ip_attention_probs, ip_value) + _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states) + + mask_downsample = IPAdapterMaskProcessor.downsample( + mask[:, i, :, :], + batch_size, + _current_ip_hidden_states.shape[1], + _current_ip_hidden_states.shape[2], + ) + + mask_downsample = mask_downsample.to(dtype=query.dtype) + + hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample) + else: + ip_key = to_k_ip(current_ip_hidden_states) + ip_value = to_v_ip(current_ip_hidden_states) - mask_downsample = mask_downsample.to(dtype=query.dtype) + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) - current_ip_hidden_states = current_ip_hidden_states * mask_downsample + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + current_ip_hidden_states = ops.bmm(ip_attention_probs, ip_value) + current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states) - hidden_states = hidden_states + scale * current_ip_hidden_states + hidden_states = hidden_states + scale * current_ip_hidden_states # linear proj hidden_states = attn.to_out[0](hidden_states) @@ -1240,4 +1391,5 @@ def __call__( CustomDiffusionAttnProcessor, JointAttnProcessor, FusedJointAttnProcessor, + HunyuanAttnProcessor, ] diff --git a/mindone/diffusers/models/autoencoders/__init__.py b/mindone/diffusers/models/autoencoders/__init__.py index 201a40ff17..5c47748d62 100644 --- a/mindone/diffusers/models/autoencoders/__init__.py +++ b/mindone/diffusers/models/autoencoders/__init__.py @@ -3,3 +3,4 @@ from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder from .autoencoder_tiny import AutoencoderTiny from .consistency_decoder_vae import ConsistencyDecoderVAE +from .vq_model import VQModel diff --git a/mindone/diffusers/models/autoencoders/autoencoder_asym_kl.py b/mindone/diffusers/models/autoencoders/autoencoder_asym_kl.py index ae2e496c45..121e23a64a 100644 --- a/mindone/diffusers/models/autoencoders/autoencoder_asym_kl.py +++ b/mindone/diffusers/models/autoencoders/autoencoder_asym_kl.py @@ -158,6 +158,7 @@ def construct( mask: Optional[ms.Tensor] = None, sample_posterior: bool = False, return_dict: bool = False, + generator: Optional[np.random.Generator] = None, ) -> Union[DecoderOutput, Tuple[ms.Tensor]]: r""" Args: @@ -175,7 +176,7 @@ def construct( else: z = self.diag_gauss_dist.mode(latent) - dec = self.decode(z, sample, mask)[0] + dec = self.decode(z, generator, sample, mask)[0] if not return_dict: return (dec,) diff --git a/mindone/diffusers/models/autoencoders/autoencoder_kl.py b/mindone/diffusers/models/autoencoders/autoencoder_kl.py index 3dd38cb224..9ac1948439 100644 --- a/mindone/diffusers/models/autoencoders/autoencoder_kl.py +++ b/mindone/diffusers/models/autoencoders/autoencoder_kl.py @@ -13,18 +13,20 @@ # limitations under the License. from typing import Dict, Optional, Tuple, Union +import numpy as np + import mindspore as ms from mindspore import nn from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import FromOriginalVAEMixin +from ...loaders import FromOriginalModelMixin from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder -class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin): +class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin): r""" A VAE model with KL loss for encoding images into latents and decoding latent representations into images. @@ -57,6 +59,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin): """ _supports_gradient_checkpointing = True + _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"] @register_to_config def __init__( @@ -240,7 +243,7 @@ def _decode(self, z: ms.Tensor, return_dict: bool = False) -> Union[DecoderOutpu return DecoderOutput(sample=dec) - def decode(self, z: ms.Tensor, return_dict: bool = False) -> Union[DecoderOutput, Tuple[ms.Tensor]]: + def decode(self, z: ms.Tensor, return_dict: bool = False, generator=None) -> Union[DecoderOutput, Tuple[ms.Tensor]]: """ Decode a batch of images. @@ -267,6 +270,7 @@ def construct( sample: ms.Tensor, sample_posterior: bool = False, return_dict: bool = False, + generator: Optional[np.random.Generator] = None, ) -> Union[DecoderOutput, Tuple[ms.Tensor]]: r""" Args: @@ -279,7 +283,7 @@ def construct( x = sample latent = self.encode(x)[0] if sample_posterior: - z = self.diag_gauss_dist.sample(latent) + z = self.diag_gauss_dist.sample(latent, generator=generator) else: z = self.diag_gauss_dist.mode(latent) dec = self.decode(z)[0] diff --git a/mindone/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py b/mindone/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py index 02cb78a257..29f3a32f3c 100644 --- a/mindone/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +++ b/mindone/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py @@ -11,7 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Tuple, Union +from typing import Dict, Optional, Tuple, Union + +import numpy as np import mindspore as ms from mindspore import nn, ops @@ -282,11 +284,13 @@ def encode( Args: x (`ms.Tensor`): Input batch of images. return_dict (`bool`, *optional*, defaults to `True`): - Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + Whether to return a [`~models.autoencoders.autoencoder_kl.AutoencoderKLOutput`] instead of a plain + tuple. Returns: The latent representations of the encoded images. If `return_dict` is True, a - [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + [`~models.autoencoders.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is + returned. """ h = self.encoder(x) moments = self.quant_conv(h) @@ -330,6 +334,7 @@ def construct( sample: ms.Tensor, sample_posterior: bool = False, return_dict: bool = False, + generator: Optional[np.random.Generator] = None, num_frames: int = 1, ) -> Union[DecoderOutput, ms.Tensor]: r""" @@ -343,7 +348,7 @@ def construct( x = sample latent = self.encode(x)[0] if sample_posterior: - z = self.diag_gauss_dist.sample(latent) + z = self.diag_gauss_dist.sample(latent, generator=generator) else: z = self.diag_gauss_dist.mode(latent) diff --git a/mindone/diffusers/models/autoencoders/autoencoder_tiny.py b/mindone/diffusers/models/autoencoders/autoencoder_tiny.py index 09edf20256..69e9f845a1 100644 --- a/mindone/diffusers/models/autoencoders/autoencoder_tiny.py +++ b/mindone/diffusers/models/autoencoders/autoencoder_tiny.py @@ -14,7 +14,9 @@ from dataclasses import dataclass -from typing import Tuple, Union +from typing import Optional, Tuple, Union + +import numpy as np import mindspore as ms from mindspore import ops @@ -102,6 +104,7 @@ def __init__( encoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64), decoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64), act_fn: str = "relu", + upsample_fn: str = "nearest", latent_channels: int = 4, upsampling_scaling_factor: int = 2, num_encoder_blocks: Tuple[int, ...] = (1, 3, 3, 3), @@ -133,6 +136,7 @@ def __init__( block_out_channels=decoder_block_out_channels, upsampling_scaling_factor=upsampling_scaling_factor, act_fn=act_fn, + upsample_fn=upsample_fn, ) self.latent_magnitude = latent_magnitude @@ -171,7 +175,9 @@ def encode(self, x: ms.Tensor, return_dict: bool = False) -> Union[AutoencoderTi return AutoencoderTinyOutput(latents=output) - def decode(self, x: ms.Tensor, return_dict: bool = False) -> Union[DecoderOutput, Tuple[ms.Tensor]]: + def decode( + self, x: ms.Tensor, generator: Optional[np.random.Generator] = None, return_dict: bool = False + ) -> Union[DecoderOutput, Tuple[ms.Tensor]]: output = self.decoder(x) if not return_dict: diff --git a/mindone/diffusers/models/autoencoders/consistency_decoder_vae.py b/mindone/diffusers/models/autoencoders/consistency_decoder_vae.py index a5f0ee88d5..f545b6a45f 100644 --- a/mindone/diffusers/models/autoencoders/consistency_decoder_vae.py +++ b/mindone/diffusers/models/autoencoders/consistency_decoder_vae.py @@ -57,7 +57,8 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin): ... "runwayml/stable-diffusion-v1-5", vae=vae, mindspore_dtype=mindspore.float16 ... ) - >>> pipe("horse").images + >>> image = pipe("horse")[0][0] + >>> image ``` """ @@ -66,6 +67,7 @@ def __init__( self, scaling_factor: float = 0.18215, latent_channels: int = 4, + sample_size: int = 32, encoder_act_fn: str = "silu", encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), encoder_double_z: bool = True, @@ -142,6 +144,16 @@ def __init__( self.use_slicing = False self.use_tiling = False + # only relevant if vae tiling is enabled + self.tile_sample_min_size = self.config.sample_size + sample_size = ( + self.config.sample_size[0] + if isinstance(self.config.sample_size, (list, tuple)) + else self.config.sample_size + ) + self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) + self.tile_overlap_factor = 0.25 + @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: # type: ignore @@ -222,13 +234,13 @@ def encode( Args: x (`ms.Tensor`): Input batch of images. return_dict (`bool`, *optional*, defaults to `True`): - Whether to return a [`~models.consistecy_decoder_vae.ConsistencyDecoderOoutput`] instead of a plain - tuple. + Whether to return a [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] + instead of a plain tuple. Returns: The latent representations of the encoded images. If `return_dict` is True, a - [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] is returned, otherwise a plain `tuple` - is returned. + [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] is returned, otherwise a + plain `tuple` is returned. """ h = self.encoder(x) @@ -246,6 +258,19 @@ def decode( return_dict: bool = False, num_inference_steps: int = 2, ) -> Union[DecoderOutput, Tuple[ms.Tensor]]: + """ + Decodes the input latent vector `z` using the consistency decoder VAE model. + + Args: + z (ms.Tensor): The input latent vector. + generator (Optional[np.random.Generator]): The random number generator. Default is None. + return_dict (bool): Whether to return the output as a dictionary. Default is True. + num_inference_steps (int): The number of inference steps. Default is 2. + + Returns: + Union[DecoderOutput, Tuple[ms.Tensor]]: The decoded output. + + """ z = ((z * self.config["scaling_factor"] - self.means) / self.stds).to(z.dtype) scale_factor = 2 ** (len(self.config["block_out_channels"]) - 1) diff --git a/mindone/diffusers/models/autoencoders/vae.py b/mindone/diffusers/models/autoencoders/vae.py index c8141a3e88..1980f55239 100644 --- a/mindone/diffusers/models/autoencoders/vae.py +++ b/mindone/diffusers/models/autoencoders/vae.py @@ -39,6 +39,7 @@ class DecoderOutput(BaseOutput): """ sample: ms.Tensor + commit_loss: Optional[ms.Tensor] = None class Encoder(nn.Cell): @@ -444,7 +445,6 @@ def __init__( has_bias=True, ) - self.mid_block = None self.up_blocks = [] temb_channels = in_channels if norm_type == "spatial" else None @@ -832,6 +832,7 @@ def __init__( block_out_channels: Tuple[int, ...], upsampling_scaling_factor: int, act_fn: str, + upsample_fn: str, ): super().__init__() @@ -848,7 +849,7 @@ def __init__( layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn)) if not is_final_block: - layers.append(DecoderTinyUpsample(scale_factor=upsampling_scaling_factor)) + layers.append(DecoderTinyUpsample(scale_factor=upsampling_scaling_factor, mode=upsample_fn)) conv_out_channel = num_channels if not is_final_block else out_channels layers.append( diff --git a/mindone/diffusers/models/autoencoders/vq_model.py b/mindone/diffusers/models/autoencoders/vq_model.py new file mode 100644 index 0000000000..2845e3c6e6 --- /dev/null +++ b/mindone/diffusers/models/autoencoders/vq_model.py @@ -0,0 +1,176 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import mindspore as ms +from mindspore import nn, ops + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput +from ..autoencoders.vae import Decoder, DecoderOutput, Encoder, VectorQuantizer +from ..modeling_utils import ModelMixin + + +@dataclass +class VQEncoderOutput(BaseOutput): + """ + Output of VQModel encoding method. + + Args: + latents (`ms.Tensor` of shape `(batch_size, num_channels, height, width)`): + The encoded output sample from the last layer of the model. + """ + + latents: ms.Tensor + + +class VQModel(ModelMixin, ConfigMixin): + r""" + A VQ-VAE model for decoding latent representations. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of block output channels. + layers_per_block (`int`, *optional*, defaults to `1`): Number of layers per block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): Sample input size. + num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE. + norm_num_groups (`int`, *optional*, defaults to `32`): Number of groups for normalization layers. + vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE. + scaling_factor (`float`, *optional*, defaults to `0.18215`): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + norm_type (`str`, *optional*, defaults to `"group"`): + Type of normalization layer to use. Can be one of `"group"` or `"spatial"`. + """ + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",), + up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",), + block_out_channels: Tuple[int, ...] = (64,), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 3, + sample_size: int = 32, + num_vq_embeddings: int = 256, + norm_num_groups: int = 32, + vq_embed_dim: Optional[int] = None, + scaling_factor: float = 0.18215, + norm_type: str = "group", # group, spatial + mid_block_add_attention=True, + lookup_from_codebook=False, + force_upcast=False, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=False, + mid_block_add_attention=mid_block_add_attention, + ) + + vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels + + self.quant_conv = nn.Conv2d(latent_channels, vq_embed_dim, 1, has_bias=True) + self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False) + self.post_quant_conv = nn.Conv2d(vq_embed_dim, latent_channels, 1, has_bias=True) + + # pass init params to Decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + norm_type=norm_type, + mid_block_add_attention=mid_block_add_attention, + ) + + def encode(self, x: ms.Tensor, return_dict: bool = False): + h = self.encoder(x) + h = self.quant_conv(h) + + if not return_dict: + return (h,) + + return VQEncoderOutput(latents=h) + + def decode(self, h: ms.Tensor, force_not_quantize: bool = False, return_dict: bool = False, shape=None): + # also go through quantization layer + if not force_not_quantize: + quant, commit_loss, _ = self.quantize(h) + elif self.config["lookup_from_codebook"]: + quant = self.quantize.get_codebook_entry(h, shape) + commit_loss = ops.zeros((h.shape[0],), dtype=h.dtype) + else: + quant = h + commit_loss = ops.zeros((h.shape[0],), dtype=h.dtype) + quant2 = self.post_quant_conv(quant) + dec = self.decoder(quant2, quant if self.config["norm_type"] == "spatial" else None) + + if not return_dict: + return dec, commit_loss + + return DecoderOutput(sample=dec, commit_loss=commit_loss) + + def construct(self, sample: ms.Tensor, return_dict: bool = False) -> Union[DecoderOutput, Tuple[ms.Tensor, ...]]: + r""" + The [`VQModel`] forward method. + + Args: + sample (`ms.Tensor`): Input sample. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.vq_model.VQEncoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vq_model.VQEncoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vq_model.VQEncoderOutput`] is returned, otherwise a plain `tuple` + is returned. + """ + + h = self.encode(sample)[0] + dec = self.decode(h) + + if not return_dict: + return dec # (dec.sample, dec.commit_loss) + + return DecoderOutput(sample=dec[0], commit_loss=dec[1]) diff --git a/mindone/diffusers/models/controlnet.py b/mindone/diffusers/models/controlnet.py index 61e550484d..770718f768 100644 --- a/mindone/diffusers/models/controlnet.py +++ b/mindone/diffusers/models/controlnet.py @@ -18,7 +18,7 @@ from mindspore import nn, ops from ..configuration_utils import ConfigMixin, register_to_config -from ..loaders import FromOriginalControlNetMixin +from ..loaders import FromOriginalModelMixin from ..utils import BaseOutput, logging from .attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps @@ -113,7 +113,7 @@ def construct(self, conditioning): return embedding -class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin): +class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): """ A ControlNet model. @@ -532,6 +532,9 @@ def from_unet( controlnet.class_embedding, unet.class_embedding.parameters_dict(), strict_load=True ) + if hasattr(controlnet, "add_embedding"): + ms.load_param_into_net(controlnet.add_embedding, unet.add_embedding.parameters_dict()) + ms.load_param_into_net(controlnet.down_blocks, unet.down_blocks.parameters_dict(), strict_load=True) ms.load_param_into_net(controlnet.mid_block, unet.mid_block.parameters_dict(), strict_load=True) @@ -803,7 +806,10 @@ def construct( if guess_mode and not self.config["global_pool_conditions"]: scales = ops.logspace(-1.0, 0.0, len(down_block_res_samples) + 1) # 0.1 to 1.0 scales = scales * conditioning_scale - down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] + # Cast scale to sample.dtype manually as torch do the same automatically for scaler tensors + down_block_res_samples = [ + sample * scale.to(sample.dtype) for sample, scale in zip(down_block_res_samples, scales) + ] mid_block_res_sample = mid_block_res_sample * scales[-1] # last one else: down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] @@ -824,7 +830,8 @@ def construct( def zero_module(module: nn.Cell): - raise RuntimeWarning( - "Method 'zero_module' does nothing because changing parameter data after initiating will " - "make parameter.set_dtype() invalid. Use arguments like 'weight_init' in instantiation instead" + logger.warning( + "Method 'zero_module' does nothing because changing parameter data after initiating will make " + "parameter.set_dtype() invalid sometimes. Use arguments like 'weight_init' in instantiation instead" ) + return module diff --git a/mindone/diffusers/models/controlnet_sd3.py b/mindone/diffusers/models/controlnet_sd3.py new file mode 100644 index 0000000000..18857b45aa --- /dev/null +++ b/mindone/diffusers/models/controlnet_sd3.py @@ -0,0 +1,358 @@ +# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import mindspore as ms +from mindspore import nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..loaders import PeftAdapterMixin +from ..models.attention import JointTransformerBlock +from ..models.attention_processor import AttentionProcessor +from ..models.modeling_utils import ModelMixin +from ..utils import logging +from .controlnet import BaseOutput +from .embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed +from .transformers.transformer_2d import Transformer2DModelOutput + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class SD3ControlNetOutput(BaseOutput): + controlnet_block_samples: Tuple[ms.Tensor] + + +class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin): + _supports_gradient_checkpointing = False # not supported now + + @register_to_config + def __init__( + self, + sample_size: int = 128, + patch_size: int = 2, + in_channels: int = 16, + num_layers: int = 18, + attention_head_dim: int = 64, + num_attention_heads: int = 18, + joint_attention_dim: int = 4096, + caption_projection_dim: int = 1152, + pooled_projection_dim: int = 2048, + out_channels: int = 16, + pos_embed_max_size: int = 96, + ): + super().__init__() + default_out_channels = in_channels + self.out_channels = out_channels if out_channels is not None else default_out_channels + self.inner_dim = num_attention_heads * attention_head_dim + + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=self.inner_dim, + pos_embed_max_size=pos_embed_max_size, + ) + self.time_text_embed = CombinedTimestepTextProjEmbeddings( + embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim + ) + self.context_embedder = nn.Dense(joint_attention_dim, caption_projection_dim) + + # `attention_head_dim` is doubled to account for the mixing. + # It needs to crafted when we get the actual checkpoints. + self.transformer_blocks = nn.CellList( + [ + JointTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=self.inner_dim, + context_pre_only=False, + ) + for i in range(num_layers) + ] + ) + + # controlnet_blocks + self.controlnet_blocks = [] + for _ in range(len(self.transformer_blocks)): + controlnet_block = nn.Dense( + self.inner_dim, + self.inner_dim, + weight_init="zeros", + bias_init="zeros", + ) # zero_module + self.controlnet_blocks.append(controlnet_block) + self.controlnet_blocks = nn.CellList(self.controlnet_blocks) + + self.pos_embed_input = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=self.inner_dim, + pos_embed_type=None, + zero_module=True, + ) # zero module, FIXME: only conv2d zero + + self.gradient_checkpointing = False + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking + def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + """ + Sets the attention processor to use [feed forward + chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). + + Parameters: + chunk_size (`int`, *optional*): + The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually + over each tensor of dim=`dim`. + dim (`int`, *optional*, defaults to `0`): + The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) + or dim=1 (sequence length). + """ + if dim not in [0, 1]: + raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") + + # By default chunk size is 1 + chunk_size = chunk_size or 1 + + def fn_recursive_feed_forward(module: nn.Cell, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.name_cells().values(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.name_cells().values(): + fn_recursive_feed_forward(module, chunk_size, dim) + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: nn.Cell, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.name_cells().items(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.name_cells().items(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: nn.Cell, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.name_cells().items(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.name_cells().items(): + fn_recursive_attn_processor(name, module, processor) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + @classmethod + def from_transformer(cls, transformer, num_layers=None, load_weights_from_transformer=True): + config = transformer.config + config["num_layers"] = num_layers or config.num_layers + controlnet = cls(**config) + + if load_weights_from_transformer: + ms.load_param_into_net(controlnet.pos_embed, transformer.pos_embed.parameters_dict()) + ms.load_param_into_net(controlnet.time_text_embed, transformer.time_text_embed.parameters_dict()) + ms.load_param_into_net(controlnet.context_embedder, transformer.context_embedder.parameters_dict()) + ms.load_param_into_net(controlnet.transformer_blocks, transformer.transformer_blocks.parameters_dict()) + + # No `zero_module` here for it is done in cls.__init__ + # controlnet.pos_embed_input = zero_module(controlnet.pos_embed_input) + + return controlnet + + def construct( + self, + hidden_states: ms.Tensor, + controlnet_cond: ms.Tensor, + conditioning_scale: float = 1.0, + encoder_hidden_states: ms.Tensor = None, + pooled_projections: ms.Tensor = None, + timestep: ms.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = False, + ) -> Union[ms.Tensor, Transformer2DModelOutput]: + """ + The [`SD3Transformer2DModel`] forward method. + + Args: + hidden_states (`ms.Tensor` of shape `(batch size, channel, height, width)`): + Input `hidden_states`. + controlnet_cond (`ms.Tensor`): + The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. + conditioning_scale (`float`, defaults to `1.0`): + The scale factor for ControlNet outputs. + encoder_hidden_states (`ms.Tensor` of shape `(batch size, sequence_len, embed_dims)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + pooled_projections (`ms.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected + from the embeddings of input conditions. + timestep ( `ms.Tensor`): + Used to indicate denoising step. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: + # weight the lora layers by setting `lora_scale` for each PEFT layer here + # and remove `lora_scale` from each PEFT layer at the end. + # scale_lora_layers & unscale_lora_layers maybe contains some operation forbidden in graph mode + raise RuntimeError( + f"You are trying to set scaling of lora layer by passing {joint_attention_kwargs['scale']}. " + f"However it's not allowed in on-the-fly model forwarding. " + f"Please manually call `scale_lora_layers(model, lora_scale)` before model forwarding and " + f"`unscale_lora_layers(model, lora_scale)` after model forwarding. " + f"For example, it can be done in a pipeline call like `StableDiffusionPipeline.__call__`." + ) + + height, width = hidden_states.shape[-2:] + + hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too. + temb = self.time_text_embed(timestep, pooled_projections) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + # add + hidden_states = hidden_states + self.pos_embed_input(controlnet_cond) + + block_res_samples = () + + for block in self.transformer_blocks: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb + ) + + block_res_samples = block_res_samples + (hidden_states,) + + controlnet_block_res_samples = () + for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks): + block_res_sample = controlnet_block(block_res_sample) + controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,) + + # 6. scaling + controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples] + + if not return_dict: + return (controlnet_block_res_samples,) + + return SD3ControlNetOutput(controlnet_block_samples=controlnet_block_res_samples) + + +class SD3MultiControlNetModel(ModelMixin): + r""" + `SD3ControlNetModel` wrapper class for Multi-SD3ControlNet + + This module is a wrapper for multiple instances of the `SD3ControlNetModel`. The `forward()` API is designed to be + compatible with `SD3ControlNetModel`. + + Args: + controlnets (`List[SD3ControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. You must set multiple + `SD3ControlNetModel` as a list. + """ + + def __init__(self, controlnets): + super().__init__() + self.nets = nn.CellList(controlnets) + + def construct( + self, + hidden_states: ms.Tensor, + controlnet_cond: List[ms.Tensor], + conditioning_scale: List[float], + pooled_projections: ms.Tensor, + encoder_hidden_states: ms.Tensor = None, + timestep: ms.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = False, + ) -> Union[SD3ControlNetOutput, Tuple]: + for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)): + block_samples = controlnet( + hidden_states=hidden_states, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + pooled_projections=pooled_projections, + controlnet_cond=image, + conditioning_scale=scale, + joint_attention_kwargs=joint_attention_kwargs, + return_dict=return_dict, + ) + + # merge samples + if i == 0: + control_block_samples = block_samples + else: + control_block_samples = [ + control_block_sample + block_sample + for control_block_sample, block_sample in zip(control_block_samples[0], block_samples[0]) + ] + control_block_samples = (tuple(control_block_samples),) + + return control_block_samples diff --git a/mindone/diffusers/models/controlnet_xs.py b/mindone/diffusers/models/controlnet_xs.py new file mode 100644 index 0000000000..03204c57c4 --- /dev/null +++ b/mindone/diffusers/models/controlnet_xs.py @@ -0,0 +1,1843 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import OrderedDict +from dataclasses import dataclass +from math import gcd +from typing import Any, Dict, List, Optional, Tuple, Union + +import mindspore as ms +from mindspore import nn, ops + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, logging +from .activations import SiLU +from .attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor +from .controlnet import ControlNetConditioningEmbedding +from .embeddings import TimestepEmbedding, Timesteps +from .modeling_utils import ModelMixin +from .normalization import GroupNorm +from .unets.unet_2d_blocks import ( + CrossAttnDownBlock2D, + CrossAttnUpBlock2D, + Downsample2D, + ResnetBlock2D, + Transformer2DModel, + UNetMidBlock2DCrossAttn, + Upsample2D, +) +from .unets.unet_2d_condition import UNet2DConditionModel + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class ControlNetXSOutput(BaseOutput): + """ + The output of [`UNetControlNetXSModel`]. + + Args: + sample (`Tensor` of shape `(batch_size, num_channels, height, width)`): + The output of the `UNetControlNetXSModel`. Unlike `ControlNetOutput` this is NOT to be added to the base + model output, but is already the final output. + """ + + sample: ms.Tensor = None + + +class DownBlockControlNetXSAdapter(nn.Cell): + """Components that together with corresponding components from the base model will form a + `ControlNetXSCrossAttnDownBlock2D`""" + + def __init__( + self, + resnets: nn.CellList, + base_to_ctrl: nn.CellList, + ctrl_to_base: nn.CellList, + attentions: Optional[nn.CellList] = None, + downsampler: Optional[nn.Conv2d] = None, + ): + super().__init__() + self.resnets = resnets + self.base_to_ctrl = base_to_ctrl + self.ctrl_to_base = ctrl_to_base + self.attentions = attentions + self.downsamplers = downsampler + + +class MidBlockControlNetXSAdapter(nn.Cell): + """Components that together with corresponding components from the base model will form a + `ControlNetXSCrossAttnMidBlock2D`""" + + def __init__(self, midblock: UNetMidBlock2DCrossAttn, base_to_ctrl: nn.CellList, ctrl_to_base: nn.CellList): + super().__init__() + self.midblock = midblock + self.base_to_ctrl = base_to_ctrl + self.ctrl_to_base = ctrl_to_base + + +class UpBlockControlNetXSAdapter(nn.Cell): + """Components that together with corresponding components from the base model will form a `ControlNetXSCrossAttnUpBlock2D`""" + + def __init__(self, ctrl_to_base: nn.CellList): + super().__init__() + self.ctrl_to_base = ctrl_to_base + + +def get_down_block_adapter( + base_in_channels: int, + base_out_channels: int, + ctrl_in_channels: int, + ctrl_out_channels: int, + temb_channels: int, + max_norm_num_groups: Optional[int] = 32, + has_crossattn=True, + transformer_layers_per_block: Optional[Union[int, Tuple[int]]] = 1, + num_attention_heads: Optional[int] = 1, + cross_attention_dim: Optional[int] = 1024, + add_downsample: bool = True, + upcast_attention: Optional[bool] = False, +): + num_layers = 2 # only support sd + sdxl + + resnets = [] + attentions = [] + ctrl_to_base = [] + base_to_ctrl = [] + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + base_in_channels = base_in_channels if i == 0 else base_out_channels + ctrl_in_channels = ctrl_in_channels if i == 0 else ctrl_out_channels + + # Before the resnet/attention application, information is concatted from base to control. + # Concat doesn't require change in number of channels + base_to_ctrl.append(make_zero_conv(base_in_channels, base_in_channels)) + + resnets.append( + ResnetBlock2D( + in_channels=ctrl_in_channels + base_in_channels, # information from base is concatted to ctrl + out_channels=ctrl_out_channels, + temb_channels=temb_channels, + groups=find_largest_factor(ctrl_in_channels + base_in_channels, max_factor=max_norm_num_groups), + groups_out=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups), + eps=1e-5, + ) + ) + + if has_crossattn: + attentions.append( + Transformer2DModel( + num_attention_heads, + ctrl_out_channels // num_attention_heads, + in_channels=ctrl_out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + use_linear_projection=True, + upcast_attention=upcast_attention, + norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups), + ) + ) + + # After the resnet/attention application, information is added from control to base + # Addition requires change in number of channels + ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels)) + + if add_downsample: + # Before the downsampler application, information is concatted from base to control + # Concat doesn't require change in number of channels + base_to_ctrl.append(make_zero_conv(base_out_channels, base_out_channels)) + + downsamplers = Downsample2D( + ctrl_out_channels + base_out_channels, use_conv=True, out_channels=ctrl_out_channels, name="op" + ) + + # After the downsampler application, information is added from control to base + # Addition requires change in number of channels + ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels)) + else: + downsamplers = None + + down_block_components = DownBlockControlNetXSAdapter( + resnets=nn.CellList(resnets), + base_to_ctrl=nn.CellList(base_to_ctrl), + ctrl_to_base=nn.CellList(ctrl_to_base), + ) + + if has_crossattn: + down_block_components.attentions = nn.CellList(attentions) + if downsamplers is not None: + down_block_components.downsamplers = downsamplers + + return down_block_components + + +def get_mid_block_adapter( + base_channels: int, + ctrl_channels: int, + temb_channels: Optional[int] = None, + max_norm_num_groups: Optional[int] = 32, + transformer_layers_per_block: int = 1, + num_attention_heads: Optional[int] = 1, + cross_attention_dim: Optional[int] = 1024, + upcast_attention: bool = False, +): + # Before the midblock application, information is concatted from base to control. + # Concat doesn't require change in number of channels + base_to_ctrl = make_zero_conv(base_channels, base_channels) + + midblock = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block, + in_channels=ctrl_channels + base_channels, + out_channels=ctrl_channels, + temb_channels=temb_channels, + # number or norm groups must divide both in_channels and out_channels + resnet_groups=find_largest_factor(gcd(ctrl_channels, ctrl_channels + base_channels), max_norm_num_groups), + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + use_linear_projection=True, + upcast_attention=upcast_attention, + ) + + # After the midblock application, information is added from control to base + # Addition requires change in number of channels + ctrl_to_base = make_zero_conv(ctrl_channels, base_channels) + + return MidBlockControlNetXSAdapter(base_to_ctrl=base_to_ctrl, midblock=midblock, ctrl_to_base=ctrl_to_base) + + +def get_up_block_adapter( + out_channels: int, + prev_output_channel: int, + ctrl_skip_channels: List[int], +): + ctrl_to_base = [] + num_layers = 3 # only support sd + sdxl + for i in range(num_layers): + resnet_in_channels = prev_output_channel if i == 0 else out_channels + ctrl_to_base.append(make_zero_conv(ctrl_skip_channels[i], resnet_in_channels)) + + return UpBlockControlNetXSAdapter(ctrl_to_base=nn.CellList(ctrl_to_base)) + + +class ControlNetXSAdapter(ModelMixin, ConfigMixin): + r""" + A `ControlNetXSAdapter` model. To use it, pass it into a `UNetControlNetXSModel` (together with a + `UNet2DConditionModel` base model). + + This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for it's generic + methods implemented for all models (such as downloading or saving). + + Like `UNetControlNetXSModel`, `ControlNetXSAdapter` is compatible with StableDiffusion and StableDiffusion-XL. It's + default parameters are compatible with StableDiffusion. + + Parameters: + conditioning_channels (`int`, defaults to 3): + Number of channels of conditioning input (e.g. an image) + conditioning_channel_order (`str`, defaults to `"rgb"`): + The channel order of conditional image. Will convert to `rgb` if it's `bgr`. + conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`): + The tuple of output channels for each block in the `controlnet_cond_embedding` layer. + time_embedding_mix (`float`, defaults to 1.0): + If 0, then only the control adapters's time embedding is used. If 1, then only the base unet's time + embedding is used. Otherwise, both are combined. + learn_time_embedding (`bool`, defaults to `False`): + Whether a time embedding should be learned. If yes, `UNetControlNetXSModel` will combine the time + embeddings of the base model and the control adapter. If no, `UNetControlNetXSModel` will use the base + model's time embedding. + num_attention_heads (`list[int]`, defaults to `[4]`): + The number of attention heads. + block_out_channels (`list[int]`, defaults to `[4, 8, 16, 16]`): + The tuple of output channels for each block. + base_block_out_channels (`list[int]`, defaults to `[320, 640, 1280, 1280]`): + The tuple of output channels for each block in the base unet. + cross_attention_dim (`int`, defaults to 1024): + The dimension of the cross attention features. + down_block_types (`list[str]`, defaults to `["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"]`): + The tuple of downsample blocks to use. + sample_size (`int`, defaults to 96): + Height and width of input/output sample. + transformer_layers_per_block (`Union[int, Tuple[int]]`, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + upcast_attention (`bool`, defaults to `True`): + Whether the attention computation should always be upcasted. + max_norm_num_groups (`int`, defaults to 32): + Maximum number of groups in group normal. The actual number will the the largest divisor of the respective + channels, that is <= max_norm_num_groups. + """ + + @register_to_config + def __init__( + self, + conditioning_channels: int = 3, + conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), + time_embedding_mix: float = 1.0, + learn_time_embedding: bool = False, + num_attention_heads: Union[int, Tuple[int]] = 4, + block_out_channels: Tuple[int] = (4, 8, 16, 16), + base_block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + cross_attention_dim: int = 1024, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + sample_size: Optional[int] = 96, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + upcast_attention: bool = True, + max_norm_num_groups: int = 32, + ): + super().__init__() + + time_embedding_input_dim = base_block_out_channels[0] + time_embedding_dim = base_block_out_channels[0] * 4 + + # Check inputs + if conditioning_channel_order not in ["rgb", "bgr"]: + raise ValueError(f"unknown `conditioning_channel_order`: {conditioning_channel_order}") + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." # noqa: E501 + ) + + if not isinstance(transformer_layers_per_block, (list, tuple)): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + if not isinstance(cross_attention_dim, (list, tuple)): + cross_attention_dim = [cross_attention_dim] * len(down_block_types) + # see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why `ControlNetXSAdapter` takes `num_attention_heads` instead of `attention_head_dim` # noqa: E501 + if not isinstance(num_attention_heads, (list, tuple)): + num_attention_heads = [num_attention_heads] * len(down_block_types) + + if len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." # noqa: E501 + ) + + # 5 - Create conditioning hint embedding + self.controlnet_cond_embedding = ControlNetConditioningEmbedding( + conditioning_embedding_channels=block_out_channels[0], + block_out_channels=conditioning_embedding_out_channels, + conditioning_channels=conditioning_channels, + ) + + # time + if learn_time_embedding: + self.time_embedding = TimestepEmbedding(time_embedding_input_dim, time_embedding_dim) + else: + self.time_embedding = None + + self.down_blocks = [] + self.up_connections = [] + + # input + self.conv_in = nn.Conv2d(4, block_out_channels[0], kernel_size=3, padding=1, pad_mode="pad", has_bias=True) + self.control_to_base_for_conv_in = make_zero_conv(block_out_channels[0], base_block_out_channels[0]) + + # down + base_out_channels = base_block_out_channels[0] + ctrl_out_channels = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + base_in_channels = base_out_channels + base_out_channels = base_block_out_channels[i] + ctrl_in_channels = ctrl_out_channels + ctrl_out_channels = block_out_channels[i] + has_crossattn = "CrossAttn" in down_block_type + is_final_block = i == len(down_block_types) - 1 + + self.down_blocks.append( + get_down_block_adapter( + base_in_channels=base_in_channels, + base_out_channels=base_out_channels, + ctrl_in_channels=ctrl_in_channels, + ctrl_out_channels=ctrl_out_channels, + temb_channels=time_embedding_dim, + max_norm_num_groups=max_norm_num_groups, + has_crossattn=has_crossattn, + transformer_layers_per_block=transformer_layers_per_block[i], + num_attention_heads=num_attention_heads[i], + cross_attention_dim=cross_attention_dim[i], + add_downsample=not is_final_block, + upcast_attention=upcast_attention, + ) + ) + self.down_blocks = nn.CellList(self.down_blocks) + + # mid + self.mid_block = get_mid_block_adapter( + base_channels=base_block_out_channels[-1], + ctrl_channels=block_out_channels[-1], + temb_channels=time_embedding_dim, + transformer_layers_per_block=transformer_layers_per_block[-1], + num_attention_heads=num_attention_heads[-1], + cross_attention_dim=cross_attention_dim[-1], + upcast_attention=upcast_attention, + ) + + # up + # The skip connection channels are the output of the conv_in and of all the down subblocks + ctrl_skip_channels = [block_out_channels[0]] + for i, out_channels in enumerate(block_out_channels): + number_of_subblocks = ( + 3 if i < len(block_out_channels) - 1 else 2 + ) # every block has 3 subblocks, except last one, which has 2 as it has no downsampler + ctrl_skip_channels.extend([out_channels] * number_of_subblocks) + + reversed_base_block_out_channels = list(reversed(base_block_out_channels)) + + base_out_channels = reversed_base_block_out_channels[0] + for i in range(len(down_block_types)): + prev_base_output_channel = base_out_channels + base_out_channels = reversed_base_block_out_channels[i] + ctrl_skip_channels_ = [ctrl_skip_channels.pop() for _ in range(3)] + + self.up_connections.append( + get_up_block_adapter( + out_channels=base_out_channels, + prev_output_channel=prev_base_output_channel, + ctrl_skip_channels=ctrl_skip_channels_, + ) + ) + self.up_connections = nn.CellList(self.up_connections) + + @classmethod + def from_unet( + cls, + unet: UNet2DConditionModel, + size_ratio: Optional[float] = None, + block_out_channels: Optional[List[int]] = None, + num_attention_heads: Optional[List[int]] = None, + learn_time_embedding: bool = False, + time_embedding_mix: int = 1.0, + conditioning_channels: int = 3, + conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), + ): + r""" + Instantiate a [`ControlNetXSAdapter`] from a [`UNet2DConditionModel`]. + + Parameters: + unet (`UNet2DConditionModel`): + The UNet model we want to control. The dimensions of the ControlNetXSAdapter will be adapted to it. + size_ratio (float, *optional*, defaults to `None`): + When given, block_out_channels is set to a fraction of the base model's block_out_channels. Either this + or `block_out_channels` must be given. + block_out_channels (`List[int]`, *optional*, defaults to `None`): + Down blocks output channels in control model. Either this or `size_ratio` must be given. + num_attention_heads (`List[int]`, *optional*, defaults to `None`): + The dimension of the attention heads. The naming seems a bit confusing and it is, see + https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why. + learn_time_embedding (`bool`, defaults to `False`): + Whether the `ControlNetXSAdapter` should learn a time embedding. + time_embedding_mix (`float`, defaults to 1.0): + If 0, then only the control adapter's time embedding is used. If 1, then only the base unet's time + embedding is used. Otherwise, both are combined. + conditioning_channels (`int`, defaults to 3): + Number of channels of conditioning input (e.g. an image) + conditioning_channel_order (`str`, defaults to `"rgb"`): + The channel order of conditional image. Will convert to `rgb` if it's `bgr`. + conditioning_embedding_out_channels (`Tuple[int]`, defaults to `(16, 32, 96, 256)`): + The tuple of output channel for each block in the `controlnet_cond_embedding` layer. + """ + + # Check input + fixed_size = block_out_channels is not None + relative_size = size_ratio is not None + if not (fixed_size ^ relative_size): + raise ValueError( + "Pass exactly one of `block_out_channels` (for absolute sizing) or `size_ratio` (for relative sizing)." + ) + + # Create model + block_out_channels = block_out_channels or [int(b * size_ratio) for b in unet.config.block_out_channels] + if num_attention_heads is None: + # The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why. # noqa: E501 + num_attention_heads = unet.config.attention_head_dim + + model = cls( + conditioning_channels=conditioning_channels, + conditioning_channel_order=conditioning_channel_order, + conditioning_embedding_out_channels=conditioning_embedding_out_channels, + time_embedding_mix=time_embedding_mix, + learn_time_embedding=learn_time_embedding, + num_attention_heads=num_attention_heads, + block_out_channels=block_out_channels, + base_block_out_channels=unet.config.block_out_channels, + cross_attention_dim=unet.config.cross_attention_dim, + down_block_types=unet.config.down_block_types, + sample_size=unet.config.sample_size, + transformer_layers_per_block=unet.config.transformer_layers_per_block, + upcast_attention=unet.config.upcast_attention, + max_norm_num_groups=unet.config.norm_num_groups, + ) + + # ensure that the ControlNetXSAdapter is the same dtype as the UNet2DConditionModel + model.to(unet.dtype) + + return model + + def construct(self, *args, **kwargs): + raise ValueError( + "A ControlNetXSAdapter cannot be run by itself. Use it together with a UNet2DConditionModel to instantiate a UNetControlNetXSModel." # noqa: E501 + ) + + +class UNetControlNetXSModel(ModelMixin, ConfigMixin): + r""" + A UNet fused with a ControlNet-XS adapter model + + This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for it's generic + methods implemented for all models (such as downloading or saving). + + `UNetControlNetXSModel` is compatible with StableDiffusion and StableDiffusion-XL. It's default parameters are + compatible with StableDiffusion. + + It's parameters are either passed to the underlying `UNet2DConditionModel` or used exactly like in + `ControlNetXSAdapter` . See their documentation for details. + """ + + _supports_gradient_checkpointing = False # not supported now + + @register_to_config + def __init__( + self, + # unet configs + sample_size: Optional[int] = 96, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + norm_num_groups: Optional[int] = 32, + cross_attention_dim: Union[int, Tuple[int]] = 1024, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + num_attention_heads: Union[int, Tuple[int]] = 8, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + upcast_attention: bool = True, + time_cond_proj_dim: Optional[int] = None, + projection_class_embeddings_input_dim: Optional[int] = None, + # additional controlnet configs + time_embedding_mix: float = 1.0, + ctrl_conditioning_channels: int = 3, + ctrl_conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), + ctrl_conditioning_channel_order: str = "rgb", + ctrl_learn_time_embedding: bool = False, + ctrl_block_out_channels: Tuple[int] = (4, 8, 16, 16), + ctrl_num_attention_heads: Union[int, Tuple[int]] = 4, + ctrl_max_norm_num_groups: int = 32, + ): + super().__init__() + + if time_embedding_mix < 0 or time_embedding_mix > 1: + raise ValueError("`time_embedding_mix` needs to be between 0 and 1.") + if time_embedding_mix < 1 and not ctrl_learn_time_embedding: + raise ValueError("To use `time_embedding_mix` < 1, `ctrl_learn_time_embedding` must be `True`") + + if addition_embed_type is not None and addition_embed_type != "text_time": + raise ValueError( + "As `UNetControlNetXSModel` currently only supports StableDiffusion and StableDiffusion-XL, `addition_embed_type` must be `None` or `'text_time'`." # noqa: E501 + ) + + if not isinstance(transformer_layers_per_block, (list, tuple)): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + if not isinstance(cross_attention_dim, (list, tuple)): + cross_attention_dim = [cross_attention_dim] * len(down_block_types) + if not isinstance(num_attention_heads, (list, tuple)): + num_attention_heads = [num_attention_heads] * len(down_block_types) + if not isinstance(ctrl_num_attention_heads, (list, tuple)): + ctrl_num_attention_heads = [ctrl_num_attention_heads] * len(down_block_types) + + base_num_attention_heads = num_attention_heads + + self.in_channels = 4 + + # # Input + self.base_conv_in = nn.Conv2d(4, block_out_channels[0], kernel_size=3, padding=1, pad_mode="pad", has_bias=True) + self.controlnet_cond_embedding = ControlNetConditioningEmbedding( + conditioning_embedding_channels=ctrl_block_out_channels[0], + block_out_channels=ctrl_conditioning_embedding_out_channels, + conditioning_channels=ctrl_conditioning_channels, + ) + self.ctrl_conv_in = nn.Conv2d( + 4, ctrl_block_out_channels[0], kernel_size=3, padding=1, pad_mode="pad", has_bias=True + ) + self.control_to_base_for_conv_in = make_zero_conv(ctrl_block_out_channels[0], block_out_channels[0]) + + # # Time + time_embed_input_dim = block_out_channels[0] + time_embed_dim = block_out_channels[0] * 4 + + self.base_time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos=True, downscale_freq_shift=0) + self.base_time_embedding = TimestepEmbedding( + time_embed_input_dim, + time_embed_dim, + cond_proj_dim=time_cond_proj_dim, + ) + self.ctrl_time_embedding = TimestepEmbedding(in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim) + + if addition_embed_type is None: + self.base_add_time_proj = None + self.base_add_embedding = None + else: + self.base_add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.base_add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + + # # Create down blocks + down_blocks = [] + base_out_channels = block_out_channels[0] + ctrl_out_channels = ctrl_block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + base_in_channels = base_out_channels + base_out_channels = block_out_channels[i] + ctrl_in_channels = ctrl_out_channels + ctrl_out_channels = ctrl_block_out_channels[i] + has_crossattn = "CrossAttn" in down_block_type + is_final_block = i == len(down_block_types) - 1 + + down_blocks.append( + ControlNetXSCrossAttnDownBlock2D( + base_in_channels=base_in_channels, + base_out_channels=base_out_channels, + ctrl_in_channels=ctrl_in_channels, + ctrl_out_channels=ctrl_out_channels, + temb_channels=time_embed_dim, + norm_num_groups=norm_num_groups, + ctrl_max_norm_num_groups=ctrl_max_norm_num_groups, + has_crossattn=has_crossattn, + transformer_layers_per_block=transformer_layers_per_block[i], + base_num_attention_heads=base_num_attention_heads[i], + ctrl_num_attention_heads=ctrl_num_attention_heads[i], + cross_attention_dim=cross_attention_dim[i], + add_downsample=not is_final_block, + upcast_attention=upcast_attention, + ) + ) + + # # Create mid block + self.mid_block = ControlNetXSCrossAttnMidBlock2D( + base_channels=block_out_channels[-1], + ctrl_channels=ctrl_block_out_channels[-1], + temb_channels=time_embed_dim, + norm_num_groups=norm_num_groups, + ctrl_max_norm_num_groups=ctrl_max_norm_num_groups, + transformer_layers_per_block=transformer_layers_per_block[-1], + base_num_attention_heads=base_num_attention_heads[-1], + ctrl_num_attention_heads=ctrl_num_attention_heads[-1], + cross_attention_dim=cross_attention_dim[-1], + upcast_attention=upcast_attention, + ) + + # # Create up blocks + up_blocks = [] + rev_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + rev_num_attention_heads = list(reversed(base_num_attention_heads)) + rev_cross_attention_dim = list(reversed(cross_attention_dim)) + + # The skip connection channels are the output of the conv_in and of all the down subblocks + ctrl_skip_channels = [ctrl_block_out_channels[0]] + for i, out_channels in enumerate(ctrl_block_out_channels): + number_of_subblocks = ( + 3 if i < len(ctrl_block_out_channels) - 1 else 2 + ) # every block has 3 subblocks, except last one, which has 2 as it has no downsampler + ctrl_skip_channels.extend([out_channels] * number_of_subblocks) + + reversed_block_out_channels = list(reversed(block_out_channels)) + + out_channels = reversed_block_out_channels[0] + up_blocks_resnets_lens = [] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = out_channels + out_channels = reversed_block_out_channels[i] + in_channels = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + ctrl_skip_channels_ = [ctrl_skip_channels.pop() for _ in range(3)] + + has_crossattn = "CrossAttn" in up_block_type + is_final_block = i == len(block_out_channels) - 1 + + up_blocks.append( + ControlNetXSCrossAttnUpBlock2D( + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + ctrl_skip_channels=ctrl_skip_channels_, + temb_channels=time_embed_dim, + resolution_idx=i, + has_crossattn=has_crossattn, + transformer_layers_per_block=rev_transformer_layers_per_block[i], + num_attention_heads=rev_num_attention_heads[i], + cross_attention_dim=rev_cross_attention_dim[i], + add_upsample=not is_final_block, + upcast_attention=upcast_attention, + norm_num_groups=norm_num_groups, + ) + ) + up_blocks_resnets_lens.append(len(up_blocks[-1].resnets)) + + self.down_blocks = nn.CellList(down_blocks) + self.up_blocks = nn.CellList(up_blocks) + self.up_blocks_resnets_lens = up_blocks_resnets_lens + + self.base_conv_norm_out = GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups) + self.base_conv_act = SiLU() + self.base_conv_out = nn.Conv2d( + block_out_channels[0], 4, kernel_size=3, padding=1, pad_mode="pad", has_bias=True + ) + + @classmethod + def from_unet( + cls, + unet: UNet2DConditionModel, + controlnet: Optional[ControlNetXSAdapter] = None, + size_ratio: Optional[float] = None, + ctrl_block_out_channels: Optional[List[float]] = None, + time_embedding_mix: Optional[float] = None, + ctrl_optional_kwargs: Optional[Dict] = None, + ): + r""" + Instantiate a [`UNetControlNetXSModel`] from a [`UNet2DConditionModel`] and an optional [`ControlNetXSAdapter`] + . + + Parameters: + unet (`UNet2DConditionModel`): + The UNet model we want to control. + controlnet (`ControlNetXSAdapter`): + The ConntrolNet-XS adapter with which the UNet will be fused. If none is given, a new ConntrolNet-XS + adapter will be created. + size_ratio (float, *optional*, defaults to `None`): + Used to contruct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details. + ctrl_block_out_channels (`List[int]`, *optional*, defaults to `None`): + Used to contruct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details, + where this parameter is called `block_out_channels`. + time_embedding_mix (`float`, *optional*, defaults to None): + Used to contruct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details. + ctrl_optional_kwargs (`Dict`, *optional*, defaults to `None`): + Passed to the `init` of the new controlent if no controlent was given. + """ + if controlnet is None: + controlnet = ControlNetXSAdapter.from_unet( + unet, size_ratio, ctrl_block_out_channels, **ctrl_optional_kwargs + ) + else: + if any( + o is not None for o in (size_ratio, ctrl_block_out_channels, time_embedding_mix, ctrl_optional_kwargs) + ): + raise ValueError( + "When a controlnet is passed, none of these parameters should be passed: size_ratio, ctrl_block_out_channels, time_embedding_mix, ctrl_optional_kwargs." # noqa: E501 + ) + + # # get params + params_for_unet = [ + "sample_size", + "down_block_types", + "up_block_types", + "block_out_channels", + "norm_num_groups", + "cross_attention_dim", + "transformer_layers_per_block", + "addition_embed_type", + "addition_time_embed_dim", + "upcast_attention", + "time_cond_proj_dim", + "projection_class_embeddings_input_dim", + ] + params_for_unet = {k: v for k, v in unet.config.items() if k in params_for_unet} + # The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why. + params_for_unet["num_attention_heads"] = unet.config.attention_head_dim + + params_for_controlnet = [ + "conditioning_channels", + "conditioning_embedding_out_channels", + "conditioning_channel_order", + "learn_time_embedding", + "block_out_channels", + "num_attention_heads", + "max_norm_num_groups", + ] + params_for_controlnet = {"ctrl_" + k: v for k, v in controlnet.config.items() if k in params_for_controlnet} + params_for_controlnet["time_embedding_mix"] = controlnet.config.time_embedding_mix + + # create model and cast dtype to ensure conbined model has the same dtype as the unet and controlnet + # and avoid tons of warning when MindSpore load parameters in state dict to those in model + model = cls.from_config({**params_for_unet, **params_for_controlnet}) + model.to(unet.dtype) + + # # load weights + # from unet + modules_from_unet = [ + "time_embedding", + "conv_in", + "conv_norm_out", + "conv_out", + ] + for m in modules_from_unet: + ms.load_param_into_net(getattr(model, "base_" + m), get_state_dict(getattr(unet, m), "base_" + m)) + + optional_modules_from_unet = [ + "add_time_proj", + "add_embedding", + ] + for m in optional_modules_from_unet: + if hasattr(unet, m) and getattr(unet, m) is not None: + ms.load_param_into_net(getattr(model, "base_" + m), get_state_dict(getattr(unet, m), "base_" + m)) + + # from controlnet + ms.load_param_into_net(model.controlnet_cond_embedding, controlnet.controlnet_cond_embedding.parameters_dict()) + # ms.load_param_into_net(model.ctrl_conv_in, controlnet.conv_in.parameters_dict()) + ms.load_param_into_net(model.ctrl_conv_in, get_state_dict(controlnet.conv_in, "ctrl_conv_in")) + if controlnet.time_embedding is not None: + # ms.load_param_into_net(model.ctrl_time_embedding, controlnet.time_embedding.parameters_dict()) + ms.load_param_into_net( + model.ctrl_time_embedding, get_state_dict(controlnet.time_embedding, "ctrl_time_embedding") + ) + ms.load_param_into_net( + model.control_to_base_for_conv_in, controlnet.control_to_base_for_conv_in.parameters_dict() + ) + + # from both + model.down_blocks = nn.CellList( + [ + ControlNetXSCrossAttnDownBlock2D.from_modules(b, c) + for b, c in zip(unet.down_blocks, controlnet.down_blocks) + ] + ) + model.mid_block = ControlNetXSCrossAttnMidBlock2D.from_modules(unet.mid_block, controlnet.mid_block) + model.up_blocks = nn.CellList( + [ + ControlNetXSCrossAttnUpBlock2D.from_modules(b, c) + for b, c in zip(unet.up_blocks, controlnet.up_connections) + ] + ) + + # ensure that the UNetControlNetXSModel is the same dtype as the UNet2DConditionModel + model.to(unet.dtype) + + return model + + def freeze_unet_params(self) -> None: + """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine + tuning.""" + # Freeze everything + for param in self.get_parameters(): + param.requires_grad = True + + # Unfreeze ControlNetXSAdapter + base_parts = [ + "base_time_proj", + "base_time_embedding", + "base_add_time_proj", + "base_add_embedding", + "base_conv_in", + "base_conv_norm_out", + "base_conv_act", + "base_conv_out", + ] + base_parts = [getattr(self, part) for part in base_parts if getattr(self, part) is not None] + for part in base_parts: + for param in part.get_parameters(): + param.requires_grad = False + + for d in self.down_blocks: + d.freeze_base_params() + self.mid_block.freeze_base_params() + for u in self.up_blocks: + u.freeze_base_params() + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: nn.Cell, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.name_cells().items(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.name_cells().items(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: nn.Cell, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.name_cells().items(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.name_cells().items(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + def construct( + self, + sample: ms.Tensor, + timestep: Union[ms.Tensor, float, int], + encoder_hidden_states: ms.Tensor, + controlnet_cond: Optional[ms.Tensor] = None, + conditioning_scale: Optional[float] = 1.0, + class_labels: Optional[ms.Tensor] = None, + timestep_cond: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, ms.Tensor]] = None, + return_dict: bool = False, + apply_control: bool = True, + ) -> Union[ControlNetXSOutput, Tuple]: + """ + The [`ControlNetXSModel`] forward method. + + Args: + sample (`Tensor`): + The noisy input tensor. + timestep (`Union[ms.Tensor, float, int]`): + The number of timesteps to denoise an input. + encoder_hidden_states (`ms.Tensor`): + The encoder hidden states. + controlnet_cond (`Tensor`): + The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. + conditioning_scale (`float`, defaults to `1.0`): + How much the control model affects the base model outputs. + class_labels (`ms.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond (`ms.Tensor`, *optional*, defaults to `None`): + Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the + timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep + embeddings. + attention_mask (`ms.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): + A kwargs dictionary that if specified is passed along to the `AttnProcessor`. + added_cond_kwargs (`dict`): + Additional conditions for the Stable Diffusion XL UNet. + return_dict (`bool`, defaults to `False`): + Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. + apply_control (`bool`, defaults to `True`): + If `False`, the input is run only through the base model. + + Returns: + [`~models.controlnetxs.ControlNetXSOutput`] **or** `tuple`: + If `return_dict` is `True`, a [`~models.controlnetxs.ControlNetXSOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + """ + + # check channel order + if self.config["ctrl_conditioning_channel_order"] == "bgr": + controlnet_cond = ops.flip(controlnet_cond, dims=[1]) + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + dtype = None + if not ops.is_tensor(timesteps): + if isinstance(timestep, float): + dtype = ms.float64 + else: + dtype = ms.int64 + timesteps = ms.Tensor([timesteps], dtype=dtype) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None] + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.broadcast_to((sample.shape[0],)) + + t_emb = self.base_time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + ctrl_temb, base_temb, interpolation_param = (None,) * 3 + if self.config["ctrl_learn_time_embedding"] and apply_control: + ctrl_temb = self.ctrl_time_embedding(t_emb, timestep_cond) + base_temb = self.base_time_embedding(t_emb, timestep_cond) + interpolation_param = self.config["time_embedding_mix"] ** 0.3 + + temb = ctrl_temb * interpolation_param + base_temb * (1 - interpolation_param) + else: + temb = self.base_time_embedding(t_emb) + + # added time & text embeddings + time_ids, time_embeds, add_embeds, aug_emb = (None,) * 4 + if self.config["addition_embed_type"] is None: + pass + elif self.config["addition_embed_type"] == "text_time": + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" # noqa: E501 + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" # noqa: E501 + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.base_add_time_proj(time_ids.flatten()).to(text_embeds.dtype) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + add_embeds = ops.concat([text_embeds, time_embeds], axis=-1) + add_embeds = add_embeds.to(temb.dtype) + aug_emb = self.base_add_embedding(add_embeds) + else: + raise ValueError( + f"ControlNet-XS currently only supports StableDiffusion and StableDiffusion-XL, so addition_embed_type = {self.config['addition_embed_type']} is currently not supported." # noqa: E501 + ) + + temb = temb + aug_emb if aug_emb is not None else temb + + # text embeddings + cemb = encoder_hidden_states + + # Preparation + h_ctrl = h_base = sample + hs_base, hs_ctrl = (), () + + # Cross Control + guided_hint = self.controlnet_cond_embedding(controlnet_cond) + + # 1 - conv in & down + + h_base = self.base_conv_in(h_base) + h_ctrl = self.ctrl_conv_in(h_ctrl) + if guided_hint is not None: + h_ctrl += guided_hint + if apply_control: + h_base = h_base + self.control_to_base_for_conv_in(h_ctrl) * conditioning_scale # add ctrl -> base + + hs_base += (h_base,) + hs_ctrl += (h_ctrl,) + + for down in self.down_blocks: + h_base, h_ctrl, residual_hb, residual_hc = down( + hidden_states_base=h_base, + hidden_states_ctrl=h_ctrl, + temb=temb, + encoder_hidden_states=cemb, + conditioning_scale=conditioning_scale, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + apply_control=apply_control, + ) + hs_base += residual_hb + hs_ctrl += residual_hc + + # 2 - mid + h_base, h_ctrl = self.mid_block( + hidden_states_base=h_base, + hidden_states_ctrl=h_ctrl, + temb=temb, + encoder_hidden_states=cemb, + conditioning_scale=conditioning_scale, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + apply_control=apply_control, + ) + + # 3 - up + for i, up in enumerate(self.up_blocks): + n_resnets = self.up_blocks_resnets_lens[i] + skips_hb = hs_base[-n_resnets:] + skips_hc = hs_ctrl[-n_resnets:] + hs_base = hs_base[:-n_resnets] + hs_ctrl = hs_ctrl[:-n_resnets] + h_base = up( + hidden_states=h_base, + res_hidden_states_tuple_base=skips_hb, + res_hidden_states_tuple_ctrl=skips_hc, + temb=temb, + encoder_hidden_states=cemb, + conditioning_scale=conditioning_scale, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + apply_control=apply_control, + ) + + # 4 - conv out + h_base = self.base_conv_norm_out(h_base) + h_base = self.base_conv_act(h_base) + h_base = self.base_conv_out(h_base) + + if not return_dict: + return (h_base,) + + return ControlNetXSOutput(sample=h_base) + + +class ControlNetXSCrossAttnDownBlock2D(nn.Cell): + def __init__( + self, + base_in_channels: int, + base_out_channels: int, + ctrl_in_channels: int, + ctrl_out_channels: int, + temb_channels: int, + norm_num_groups: int = 32, + ctrl_max_norm_num_groups: int = 32, + has_crossattn=True, + transformer_layers_per_block: Optional[Union[int, Tuple[int]]] = 1, + base_num_attention_heads: Optional[int] = 1, + ctrl_num_attention_heads: Optional[int] = 1, + cross_attention_dim: Optional[int] = 1024, + add_downsample: bool = True, + upcast_attention: Optional[bool] = False, + ): + super().__init__() + base_resnets = [] + base_attentions = [] + ctrl_resnets = [] + ctrl_attentions = [] + ctrl_to_base = [] + base_to_ctrl = [] + + num_layers = 2 # only support sd + sdxl + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + base_in_channels = base_in_channels if i == 0 else base_out_channels + ctrl_in_channels = ctrl_in_channels if i == 0 else ctrl_out_channels + + # Before the resnet/attention application, information is concatted from base to control. + # Concat doesn't require change in number of channels + base_to_ctrl.append(make_zero_conv(base_in_channels, base_in_channels)) + + base_resnets.append( + ResnetBlock2D( + in_channels=base_in_channels, + out_channels=base_out_channels, + temb_channels=temb_channels, + groups=norm_num_groups, + ) + ) + ctrl_resnets.append( + ResnetBlock2D( + in_channels=ctrl_in_channels + base_in_channels, # information from base is concatted to ctrl + out_channels=ctrl_out_channels, + temb_channels=temb_channels, + groups=find_largest_factor( + ctrl_in_channels + base_in_channels, max_factor=ctrl_max_norm_num_groups + ), + groups_out=find_largest_factor(ctrl_out_channels, max_factor=ctrl_max_norm_num_groups), + eps=1e-5, + ) + ) + + if has_crossattn: + base_attentions.append( + Transformer2DModel( + base_num_attention_heads, + base_out_channels // base_num_attention_heads, + in_channels=base_out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + use_linear_projection=True, + upcast_attention=upcast_attention, + norm_num_groups=norm_num_groups, + ) + ) + ctrl_attentions.append( + Transformer2DModel( + ctrl_num_attention_heads, + ctrl_out_channels // ctrl_num_attention_heads, + in_channels=ctrl_out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + use_linear_projection=True, + upcast_attention=upcast_attention, + norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=ctrl_max_norm_num_groups), + ) + ) + + # After the resnet/attention application, information is added from control to base + # Addition requires change in number of channels + ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels)) + + if add_downsample: + # Before the downsampler application, information is concatted from base to control + # Concat doesn't require change in number of channels + base_to_ctrl.append(make_zero_conv(base_out_channels, base_out_channels)) + + self.base_downsamplers = Downsample2D( + base_out_channels, use_conv=True, out_channels=base_out_channels, name="op" + ) + self.ctrl_downsamplers = Downsample2D( + ctrl_out_channels + base_out_channels, use_conv=True, out_channels=ctrl_out_channels, name="op" + ) + + # After the downsampler application, information is added from control to base + # Addition requires change in number of channels + ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels)) + else: + self.base_downsamplers = None + self.ctrl_downsamplers = None + + self.base_resnets = nn.CellList(base_resnets) + self.ctrl_resnets = nn.CellList(ctrl_resnets) + self.base_attentions = nn.CellList(base_attentions) if has_crossattn else [None] * num_layers + self.ctrl_attentions = nn.CellList(ctrl_attentions) if has_crossattn else [None] * num_layers + self.base_to_ctrl = nn.CellList(base_to_ctrl) + self.ctrl_to_base = nn.CellList(ctrl_to_base) + + self.gradient_checkpointing = False + + @classmethod + def from_modules(cls, base_downblock: CrossAttnDownBlock2D, ctrl_downblock: DownBlockControlNetXSAdapter): + # get params + def get_first_cross_attention(block): + return block.attentions[0].transformer_blocks[0].attn2 + + base_in_channels = base_downblock.resnets[0].in_channels + base_out_channels = base_downblock.resnets[0].out_channels + ctrl_in_channels = ( + ctrl_downblock.resnets[0].in_channels - base_in_channels + ) # base channels are concatted to ctrl channels in init + ctrl_out_channels = ctrl_downblock.resnets[0].out_channels + temb_channels = base_downblock.resnets[0].time_emb_proj.in_channels + num_groups = base_downblock.resnets[0].norm1.num_groups + ctrl_num_groups = ctrl_downblock.resnets[0].norm1.num_groups + if hasattr(base_downblock, "attentions"): + has_crossattn = True + transformer_layers_per_block = len(base_downblock.attentions[0].transformer_blocks) + base_num_attention_heads = get_first_cross_attention(base_downblock).heads + ctrl_num_attention_heads = get_first_cross_attention(ctrl_downblock).heads + cross_attention_dim = get_first_cross_attention(base_downblock).cross_attention_dim + upcast_attention = get_first_cross_attention(base_downblock).upcast_attention + else: + has_crossattn = False + transformer_layers_per_block = None + base_num_attention_heads = None + ctrl_num_attention_heads = None + cross_attention_dim = None + upcast_attention = None + add_downsample = base_downblock.downsamplers is not None + + # create model + model = cls( + base_in_channels=base_in_channels, + base_out_channels=base_out_channels, + ctrl_in_channels=ctrl_in_channels, + ctrl_out_channels=ctrl_out_channels, + temb_channels=temb_channels, + norm_num_groups=num_groups, + ctrl_max_norm_num_groups=ctrl_num_groups, + has_crossattn=has_crossattn, + transformer_layers_per_block=transformer_layers_per_block, + base_num_attention_heads=base_num_attention_heads, + ctrl_num_attention_heads=ctrl_num_attention_heads, + cross_attention_dim=cross_attention_dim, + add_downsample=add_downsample, + upcast_attention=upcast_attention, + ) + sync_dtype(model, base_downblock) + + # load weights + ms.load_param_into_net(model.base_resnets, get_state_dict(base_downblock.resnets, "base_resnets")) + ms.load_param_into_net(model.ctrl_resnets, get_state_dict(ctrl_downblock.resnets, "ctrl_resnets")) + if has_crossattn: + ms.load_param_into_net(model.base_attentions, get_state_dict(base_downblock.attentions, "base_attentions")) + ms.load_param_into_net(model.ctrl_attentions, get_state_dict(ctrl_downblock.attentions, "ctrl_attentions")) + if add_downsample: + ms.load_param_into_net( + model.base_downsamplers, get_state_dict(base_downblock.downsamplers[0], "base_downsamplers") + ) + ms.load_param_into_net( + model.ctrl_downsamplers, get_state_dict(ctrl_downblock.downsamplers, "ctrl_downsamplers") + ) + ms.load_param_into_net(model.base_to_ctrl, get_state_dict(ctrl_downblock.base_to_ctrl, "base_to_ctrl")) + ms.load_param_into_net(model.ctrl_to_base, get_state_dict(ctrl_downblock.ctrl_to_base, "ctrl_to_base")) + + return model + + def freeze_base_params(self) -> None: + """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine + tuning.""" + # Unfreeze everything + for param in self.get_parameters(): + param.requires_grad = True + + # Freeze base part + base_parts = [self.base_resnets] + if isinstance(self.base_attentions, nn.CellList): # attentions can be a list of Nones + base_parts.append(self.base_attentions) + if self.base_downsamplers is not None: + base_parts.append(self.base_downsamplers) + for part in base_parts: + for param in part.get_parameters(): + param.requires_grad = False + + def construct( + self, + hidden_states_base: ms.Tensor, + temb: ms.Tensor, + encoder_hidden_states: Optional[ms.Tensor] = None, + hidden_states_ctrl: Optional[ms.Tensor] = None, + conditioning_scale: Optional[float] = 1.0, + attention_mask: Optional[ms.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[ms.Tensor] = None, + apply_control: bool = True, + ) -> Tuple[ms.Tensor, ms.Tensor, Tuple[ms.Tensor, ...], Tuple[ms.Tensor, ...]]: + if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None: + raise RuntimeError( + f"You are trying to set scaling of lora layer by passing {cross_attention_kwargs['scale']}. " + f"However it's not allowed in on-the-fly model forwarding. " + f"Please manually call `scale_lora_layers(model, lora_scale)` before model forwarding and " + f"`unscale_lora_layers(model, lora_scale)` after model forwarding. " + f"For example, it can be done in a pipeline call like `StableDiffusionPipeline.__call__`." + ) + # breakpoint() + h_base = hidden_states_base + h_ctrl = hidden_states_ctrl + + base_output_states = () + ctrl_output_states = () + + base_blocks = list(zip(self.base_resnets, self.base_attentions)) + ctrl_blocks = list(zip(self.ctrl_resnets, self.ctrl_attentions)) + + for (b_res, b_attn), (c_res, c_attn), b2c, c2b in zip( + base_blocks, ctrl_blocks, self.base_to_ctrl, self.ctrl_to_base + ): + # concat base -> ctrl + if apply_control: + h_ctrl = ops.cat([h_ctrl, b2c(h_base)], axis=1) + + # apply base subblock + h_base = b_res(h_base, temb) + + if b_attn is not None: + h_base = b_attn( + h_base, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + # apply ctrl subblock + if apply_control: + h_ctrl = c_res(h_ctrl, temb) + if c_attn is not None: + h_ctrl = c_attn( + h_ctrl, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + # add ctrl -> base + if apply_control: + h_base = h_base + c2b(h_ctrl) * conditioning_scale + + base_output_states = base_output_states + (h_base,) + ctrl_output_states = ctrl_output_states + (h_ctrl,) + + if self.base_downsamplers is not None: # if we have a base_downsampler, then also a ctrl_downsampler + b2c = self.base_to_ctrl[-1] + c2b = self.ctrl_to_base[-1] + + # concat base -> ctrl + if apply_control: + h_ctrl = ops.cat([h_ctrl, b2c(h_base)], axis=1) + # apply base subblock + h_base = self.base_downsamplers(h_base) + # apply ctrl subblock + if apply_control: + h_ctrl = self.ctrl_downsamplers(h_ctrl) + # add ctrl -> base + if apply_control: + h_base = h_base + c2b(h_ctrl) * conditioning_scale + + base_output_states = base_output_states + (h_base,) + ctrl_output_states = ctrl_output_states + (h_ctrl,) + + return h_base, h_ctrl, base_output_states, ctrl_output_states + + +class ControlNetXSCrossAttnMidBlock2D(nn.Cell): + def __init__( + self, + base_channels: int, + ctrl_channels: int, + temb_channels: Optional[int] = None, + norm_num_groups: int = 32, + ctrl_max_norm_num_groups: int = 32, + transformer_layers_per_block: int = 1, + base_num_attention_heads: Optional[int] = 1, + ctrl_num_attention_heads: Optional[int] = 1, + cross_attention_dim: Optional[int] = 1024, + upcast_attention: bool = False, + ): + super().__init__() + + # Before the midblock application, information is concatted from base to control. + # Concat doesn't require change in number of channels + self.base_to_ctrl = make_zero_conv(base_channels, base_channels) + + self.base_midblock = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block, + in_channels=base_channels, + temb_channels=temb_channels, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=base_num_attention_heads, + use_linear_projection=True, + upcast_attention=upcast_attention, + ) + + self.ctrl_midblock = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block, + in_channels=ctrl_channels + base_channels, + out_channels=ctrl_channels, + temb_channels=temb_channels, + # number or norm groups must divide both in_channels and out_channels + resnet_groups=find_largest_factor( + gcd(ctrl_channels, ctrl_channels + base_channels), ctrl_max_norm_num_groups + ), + cross_attention_dim=cross_attention_dim, + num_attention_heads=ctrl_num_attention_heads, + use_linear_projection=True, + upcast_attention=upcast_attention, + ) + + # After the midblock application, information is added from control to base + # Addition requires change in number of channels + self.ctrl_to_base = make_zero_conv(ctrl_channels, base_channels) + + self.gradient_checkpointing = False + + @classmethod + def from_modules( + cls, + base_midblock: UNetMidBlock2DCrossAttn, + ctrl_midblock: MidBlockControlNetXSAdapter, + ): + base_to_ctrl = ctrl_midblock.base_to_ctrl + ctrl_to_base = ctrl_midblock.ctrl_to_base + ctrl_midblock = ctrl_midblock.midblock + + # get params + def get_first_cross_attention(midblock): + return midblock.attentions[0].transformer_blocks[0].attn2 + + base_channels = ctrl_to_base.out_channels + ctrl_channels = ctrl_to_base.in_channels + transformer_layers_per_block = len(base_midblock.attentions[0].transformer_blocks) + temb_channels = base_midblock.resnets[0].time_emb_proj.in_channels + num_groups = base_midblock.resnets[0].norm1.num_groups + ctrl_num_groups = ctrl_midblock.resnets[0].norm1.num_groups + base_num_attention_heads = get_first_cross_attention(base_midblock).heads + ctrl_num_attention_heads = get_first_cross_attention(ctrl_midblock).heads + cross_attention_dim = get_first_cross_attention(base_midblock).cross_attention_dim + upcast_attention = get_first_cross_attention(base_midblock).upcast_attention + + # create model + model = cls( + base_channels=base_channels, + ctrl_channels=ctrl_channels, + temb_channels=temb_channels, + norm_num_groups=num_groups, + ctrl_max_norm_num_groups=ctrl_num_groups, + transformer_layers_per_block=transformer_layers_per_block, + base_num_attention_heads=base_num_attention_heads, + ctrl_num_attention_heads=ctrl_num_attention_heads, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + ) + sync_dtype(model, base_midblock) + + # load weights + ms.load_param_into_net(model.base_to_ctrl, get_state_dict(base_to_ctrl, "base_to_ctrl")) + ms.load_param_into_net(model.base_midblock, get_state_dict(base_midblock, "base_midblock")) + ms.load_param_into_net(model.ctrl_midblock, get_state_dict(ctrl_midblock, "ctrl_midblock")) + ms.load_param_into_net(model.ctrl_to_base, get_state_dict(ctrl_to_base, "ctrl_to_base")) + + return model + + def freeze_base_params(self) -> None: + """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine + tuning.""" + # Unfreeze everything + for param in self.get_parameters(): + param.requires_grad = True + + # Freeze base part + for param in self.base_midblock.get_parameters(): + param.requires_grad = False + + def construct( + self, + hidden_states_base: ms.Tensor, + temb: ms.Tensor, + encoder_hidden_states: ms.Tensor, + hidden_states_ctrl: Optional[ms.Tensor] = None, + conditioning_scale: Optional[float] = 1.0, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + attention_mask: Optional[ms.Tensor] = None, + encoder_attention_mask: Optional[ms.Tensor] = None, + apply_control: bool = True, + ) -> Tuple[ms.Tensor, ms.Tensor]: + if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None: + raise RuntimeError( + f"You are trying to set scaling of lora layer by passing {cross_attention_kwargs['scale']=}. " + f"However it's not allowed in on-the-fly model forwarding. " + f"Please manually call `scale_lora_layers(model, lora_scale)` before model forwarding and " + f"`unscale_lora_layers(model, lora_scale)` after model forwarding. " + f"For example, it can be done in a pipeline call like `StableDiffusionPipeline.__call__`." + ) + + h_base = hidden_states_base + h_ctrl = hidden_states_ctrl + + joint_args = { + "temb": temb, + "encoder_hidden_states": encoder_hidden_states, + "attention_mask": attention_mask, + "cross_attention_kwargs": cross_attention_kwargs, + "encoder_attention_mask": encoder_attention_mask, + } + + if apply_control: + h_ctrl = ops.cat([h_ctrl, self.base_to_ctrl(h_base)], axis=1) # concat base -> ctrl + h_base = self.base_midblock(h_base, **joint_args) # apply base mid block + if apply_control: + h_ctrl = self.ctrl_midblock(h_ctrl, **joint_args) # apply ctrl mid block + h_base = h_base + self.ctrl_to_base(h_ctrl) * conditioning_scale # add ctrl -> base + + return h_base, h_ctrl + + +class ControlNetXSCrossAttnUpBlock2D(nn.Cell): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + ctrl_skip_channels: List[int], + temb_channels: int, + norm_num_groups: int = 32, + resolution_idx: Optional[int] = None, + has_crossattn=True, + transformer_layers_per_block: int = 1, + num_attention_heads: int = 1, + cross_attention_dim: int = 1024, + add_upsample: bool = True, + upcast_attention: bool = False, + ): + super().__init__() + resnets = [] + attentions = [] + ctrl_to_base = [] + + num_layers = 3 # only support sd + sdxl + self.num_layers = num_layers + + self.has_cross_attention = has_crossattn + self.num_attention_heads = num_attention_heads + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + ctrl_to_base.append(make_zero_conv(ctrl_skip_channels[i], resnet_in_channels)) + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + groups=norm_num_groups, + ) + ) + + if has_crossattn: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + use_linear_projection=True, + upcast_attention=upcast_attention, + norm_num_groups=norm_num_groups, + ) + ) + + self.resnets = nn.CellList(resnets) + self.attentions = nn.CellList(attentions) if has_crossattn else [None] * num_layers + self.ctrl_to_base = nn.CellList(ctrl_to_base) + + if add_upsample: + self.upsamplers = Upsample2D(out_channels, use_conv=True, out_channels=out_channels) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + @classmethod + def from_modules(cls, base_upblock: CrossAttnUpBlock2D, ctrl_upblock: UpBlockControlNetXSAdapter): + ctrl_to_base_skip_connections = ctrl_upblock.ctrl_to_base + + # get params + def get_first_cross_attention(block): + return block.attentions[0].transformer_blocks[0].attn2 + + out_channels = base_upblock.resnets[0].out_channels + in_channels = base_upblock.resnets[-1].in_channels - out_channels + prev_output_channels = base_upblock.resnets[0].in_channels - out_channels + ctrl_skip_channelss = [c.in_channels for c in ctrl_to_base_skip_connections] + temb_channels = base_upblock.resnets[0].time_emb_proj.in_channels + num_groups = base_upblock.resnets[0].norm1.num_groups + resolution_idx = base_upblock.resolution_idx + if hasattr(base_upblock, "attentions"): + has_crossattn = True + transformer_layers_per_block = len(base_upblock.attentions[0].transformer_blocks) + num_attention_heads = get_first_cross_attention(base_upblock).heads + cross_attention_dim = get_first_cross_attention(base_upblock).cross_attention_dim + upcast_attention = get_first_cross_attention(base_upblock).upcast_attention + else: + has_crossattn = False + transformer_layers_per_block = None + num_attention_heads = None + cross_attention_dim = None + upcast_attention = None + add_upsample = base_upblock.upsamplers is not None + + # create model + model = cls( + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channels, + ctrl_skip_channels=ctrl_skip_channelss, + temb_channels=temb_channels, + norm_num_groups=num_groups, + resolution_idx=resolution_idx, + has_crossattn=has_crossattn, + transformer_layers_per_block=transformer_layers_per_block, + num_attention_heads=num_attention_heads, + cross_attention_dim=cross_attention_dim, + add_upsample=add_upsample, + upcast_attention=upcast_attention, + ) + sync_dtype(model, base_upblock) + + # load weights + ms.load_param_into_net(model.resnets, get_state_dict(base_upblock.resnets, "resnets")) + if has_crossattn: + ms.load_param_into_net(model.attentions, get_state_dict(base_upblock.attentions, "attentions")) + if add_upsample: + ms.load_param_into_net(model.upsamplers, get_state_dict(base_upblock.upsamplers[0], "upsamplers")) + ms.load_param_into_net(model.ctrl_to_base, get_state_dict(ctrl_to_base_skip_connections, "ctrl_to_base")) + + return model + + def freeze_base_params(self) -> None: + """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine + tuning.""" + # Unfreeze everything + for param in self.get_parameters(): + param.requires_grad = True + + # Freeze base part + base_parts = [self.resnets] + if isinstance(self.attentions, nn.CellList): # attentions can be a list of Nones + base_parts.append(self.attentions) + if self.upsamplers is not None: + base_parts.append(self.upsamplers) + for part in base_parts: + for param in part.get_parameters(): + param.requires_grad = False + + def construct( + self, + hidden_states: ms.Tensor, + res_hidden_states_tuple_base: Tuple[ms.Tensor, ...], + res_hidden_states_tuple_ctrl: Tuple[ms.Tensor, ...], + temb: ms.Tensor, + encoder_hidden_states: Optional[ms.Tensor] = None, + conditioning_scale: Optional[float] = 1.0, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + attention_mask: Optional[ms.Tensor] = None, + upsample_size: Optional[int] = None, + encoder_attention_mask: Optional[ms.Tensor] = None, + apply_control: bool = True, + ) -> ms.Tensor: + if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None: + raise RuntimeError( + f"You are trying to set scaling of lora layer by passing {cross_attention_kwargs['scale']=}. " + f"However it's not allowed in on-the-fly model forwarding. " + f"Please manually call `scale_lora_layers(model, lora_scale)` before model forwarding and " + f"`unscale_lora_layers(model, lora_scale)` after model forwarding. " + f"For example, it can be done in a pipeline call like `StableDiffusionPipeline.__call__`." + ) + + for i in range(self.num_layers): + # reversed operation not fully supported in GRAPH MODE, use index to get block instead. + resnet, attn, c2b, res_h_base, res_h_ctrl = ( + self.resnets[i], + self.attentions[i], + self.ctrl_to_base[i], + res_hidden_states_tuple_base[self.num_layers - i - 1], + res_hidden_states_tuple_ctrl[self.num_layers - i - 1], + ) + if apply_control: + hidden_states += c2b(res_h_ctrl) * conditioning_scale + + hidden_states = ops.cat([hidden_states, res_h_base], axis=1) + hidden_states = resnet(hidden_states, temb) + + if attn is not None: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + if self.upsamplers is not None: + hidden_states = self.upsamplers(hidden_states, upsample_size) + + return hidden_states + + +def make_zero_conv(in_channels, out_channels=None): + return nn.Conv2d( + in_channels, out_channels, 1, pad_mode="pad", padding=0, has_bias=True, weight_init="zeros", bias_init="zeros" + ) + + +def find_largest_factor(number, max_factor): + factor = max_factor + if factor >= number: + return number + while factor != 0: + residual = number % factor + if residual == 0: + return factor + factor -= 1 + + +# ============================= +# Paramater processing +# ============================= +def get_state_dict(module: nn.Cell, name_prefix="", recurse=True): + """ + A function attempting to achieve an effect similar to torch's `nn.Module.state_dict()`. + + Due to MindSpore's unique parameter naming mechanism, this function performs operations + on the prefix of parameter names. This ensures that parameters can be correctly loaded + using `mindspore.load_param_into_net()` when there are discrepancies between the parameter + names of the target_model and source_model. + """ + param_generator = module.parameters_and_names(name_prefix=name_prefix, expand=recurse) + + param_dict = OrderedDict() + for name, param in param_generator: + param.name = name + param_dict[name] = param + return param_dict + + +def sync_dtype(module: nn.Cell, other: nn.Cell): + """ + Sets the dtype of 'module' to match that of 'other'. + + This function is designed to prevent warnings arising from data type mismatches when later + synchronizing parameters between two models. In MindSpore, such mismatches can lead to + numerous warning messages, whereas PyTorch tends to handle these situations more silently. + """ + _, first_p = next(other.parameters_and_names()) + target_dtype = first_p.dtype + + for p in module.get_parameters(): + p.set_dtype(target_dtype) + + return module diff --git a/mindone/diffusers/models/downsampling.py b/mindone/diffusers/models/downsampling.py index f5d268042e..b92d988fcb 100644 --- a/mindone/diffusers/models/downsampling.py +++ b/mindone/diffusers/models/downsampling.py @@ -107,7 +107,6 @@ def __init__( self.padding = padding stride = 2 self.name = name - conv_cls = nn.Conv2d if norm_type == "ln_norm": self.norm = LayerNorm(channels, eps, elementwise_affine) @@ -119,7 +118,7 @@ def __init__( raise ValueError(f"unknown norm_type: {norm_type}") if use_conv: - conv = conv_cls( + conv = nn.Conv2d( self.channels, self.out_channels, kernel_size=kernel_size, diff --git a/mindone/diffusers/models/embeddings.py b/mindone/diffusers/models/embeddings.py index 36c0ff31fb..bb74c4ed2f 100644 --- a/mindone/diffusers/models/embeddings.py +++ b/mindone/diffusers/models/embeddings.py @@ -135,6 +135,7 @@ def __init__( interpolation_scale=1, pos_embed_type="sincos", pos_embed_max_size=None, # For SD3 cropping + zero_module=False, # For SD3 ControlNet ): super().__init__() from .normalization import LayerNorm @@ -144,6 +145,7 @@ def __init__( self.layer_norm = layer_norm self.pos_embed_max_size = pos_embed_max_size + weight_init_kwargs = {"weight_init": "zeros", "bias_init": "zeros"} if zero_module else {} self.proj = nn.Conv2d( in_channels, embed_dim, @@ -151,6 +153,7 @@ def __init__( stride=patch_size, pad_mode="pad", has_bias=bias, + **weight_init_kwargs, ) if layer_norm: self.norm = LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) @@ -232,6 +235,112 @@ def construct(self, latent): return (latent + pos_embed).to(latent.dtype) +def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True): + """ + RoPE for image tokens with 2d structure. + + Args: + embed_dim: (`int`): + The embedding dimension size + crops_coords (`Tuple[int]`) + The top-left and bottom-right coordinates of the crop. + grid_size (`Tuple[int]`): + The grid size of the positional embedding. + use_real (`bool`): + If True, return real part and imaginary part separately. Otherwise, return complex numbers. + + Returns: + `ms.Tensor`: positional embdding with shape `( grid_size * grid_size, embed_dim/2)`. + """ + start, stop = crops_coords + grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32) + grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) # [2, W, H] + + grid = grid.reshape([2, 1, *grid.shape[1:]]) + pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real) + return pos_embed + + +def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False): + assert embed_dim % 4 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real) # (H*W, D/4) + emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real) # (H*W, D/4) + + if use_real: + cos = ops.cat([emb_h[0], emb_w[0]], axis=1) # (H*W, D/2) + sin = ops.cat([emb_h[1], emb_w[1]], axis=1) # (H*W, D/2) + return cos, sin + else: + emb = ops.cat([emb_h, emb_w], axis=1) # (H*W, D/2) + return emb + + +def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False): + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end + index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64 + data type. + + Args: + dim (`int`): Dimension of the frequency tensor. + pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar + theta (`float`, *optional*, defaults to 10000.0): + Scaling factor for frequency computation. Defaults to 10000.0. + use_real (`bool`, *optional*): + If True, return real part and imaginary part separately. Otherwise, return complex numbers. + + Returns: + `ms.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2] + """ + if isinstance(pos, int): + pos = np.arange(pos) + freqs = 1.0 / (theta ** (ops.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2] + t = ms.Tensor.from_numpy(pos) # type: ignore # [S] + freqs = ops.outer(t, freqs).float() # type: ignore # [S, D/2] + if use_real: + freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] + return freqs_cos, freqs_sin + else: + freqs_cis = ops.polar(ops.ones_like(freqs), freqs) # complex64 # [S, D/2] + return freqs_cis + + +def apply_rotary_emb( + x: ms.Tensor, + freqs_cis: Union[ms.Tensor, Tuple[ms.Tensor]], +) -> Tuple[ms.Tensor, ms.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings + to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are + reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting + tensors contain rotary embeddings and are returned as real tensors. + + Args: + x (`ms.Tensor`): + Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (ms.Tensor): Key tensor to apply + freqs_cis (`Tuple[ms.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) + + Returns: + Tuple[ms.Tensor, ms.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + """ + cos, sin = freqs_cis # [S, D] + cos = cos[None, None] + sin = sin[None, None] + + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + x_rotated = ops.stack([-x_imag, x_real], axis=-1).flatten(start_dim=3) + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + return out + + class TimestepEmbedding(nn.Cell): def __init__( self, @@ -244,9 +353,8 @@ def __init__( sample_proj_bias=True, ): super().__init__() - linear_cls = nn.Dense - self.linear_1 = linear_cls(in_channels, time_embed_dim, has_bias=sample_proj_bias) + self.linear_1 = nn.Dense(in_channels, time_embed_dim, has_bias=sample_proj_bias) if cond_proj_dim is not None: self.cond_proj = nn.Dense(cond_proj_dim, in_channels, has_bias=False) @@ -259,7 +367,7 @@ def __init__( time_embed_dim_out = out_dim else: time_embed_dim_out = time_embed_dim - self.linear_2 = linear_cls(time_embed_dim, time_embed_dim_out, has_bias=sample_proj_bias) + self.linear_2 = nn.Dense(time_embed_dim, time_embed_dim_out, has_bias=sample_proj_bias) if post_act_fn is None: self.post_act = None @@ -520,6 +628,23 @@ def construct(self, image_embeds: ms.Tensor): return self.norm(self.ff(image_embeds)) +class IPAdapterFaceIDImageProjection(nn.Cell): + def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1): + super().__init__() + from .attention import FeedForward + from .normalization import LayerNorm + + self.num_tokens = num_tokens + self.cross_attention_dim = cross_attention_dim + self.ff = FeedForward(image_embed_dim, cross_attention_dim * num_tokens, mult=mult, activation_fn="gelu") + self.norm = LayerNorm(cross_attention_dim) + + def construct(self, image_embeds: ms.Tensor): + x = self.ff(image_embeds) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + return self.norm(x) + + class CombinedTimestepLabelEmbeddings(nn.Cell): def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1): super().__init__() @@ -559,6 +684,91 @@ def construct(self, timestep, pooled_projection): return conditioning +class HunyuanDiTAttentionPool(nn.Cell): + # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6 + + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = ms.Parameter( + ops.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5, name="positional_embedding" + ) + self.k_proj = nn.Dense(embed_dim, embed_dim) + self.q_proj = nn.Dense(embed_dim, embed_dim) + self.v_proj = nn.Dense(embed_dim, embed_dim) + self.c_proj = nn.Dense(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def construct(self, x: ms.Tensor): + x = x.permute(1, 0, 2) # NLC -> LNC + x = ops.cat([x.mean(axis=0, keep_dims=True), x], axis=0) # (L+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC + x, _ = ops.function.nn_func.multi_head_attention_forward( + query=x[:1], + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=ops.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0.0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + dtype=x.dtype, # mindspore must specify argument dtype, otherwise fp32 will be used + ) + return x.squeeze(0) + + +class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Cell): + def __init__(self, embedding_dim, pooled_projection_dim=1024, seq_len=256, cross_attention_dim=2048): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + self.pooler = HunyuanDiTAttentionPool( + seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim + ) + # Here we use a default learned embedder layer for future extension. + self.style_embedder = nn.Embedding(1, embedding_dim) + extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim + self.extra_embedder = PixArtAlphaTextProjection( + in_features=extra_in_dim, + hidden_size=embedding_dim * 4, + out_features=embedding_dim, + act_fn="silu_fp32", + ) + + def construct(self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None): + hidden_dtype = hidden_dtype or encoder_hidden_states.dtype # tensor.to(None) is invalid + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, 256) + + # extra condition1: text + pooled_projections = self.pooler(encoder_hidden_states) # (N, 1024) + + # extra condition2: image meta size embdding + image_meta_size = get_timestep_embedding(image_meta_size.view(-1), 256, True, 0) + image_meta_size = image_meta_size.to(dtype=hidden_dtype) + image_meta_size = image_meta_size.view(-1, 6 * 256) # (N, 1536) + + # extra condition3: style embedding + style_embedding = self.style_embedder(style) # (N, embedding_dim) + + # Concatenate all extra vectors + extra_cond = ops.cat([pooled_projections, image_meta_size, style_embedding], axis=1) + conditioning = timesteps_emb + self.extra_embedder(extra_cond) # [B, D] + + return conditioning + + class TextTimeEmbedding(nn.Cell): def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64): super().__init__() @@ -868,7 +1078,7 @@ def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tan self.act_1 = FP32SiLU() else: raise ValueError(f"Unknown activation function: {act_fn}") - self.linear_2 = nn.Dense(in_channels=hidden_size, out_channels=hidden_size, has_bias=True) + self.linear_2 = nn.Dense(in_channels=hidden_size, out_channels=out_features, has_bias=True) def construct(self, caption): hidden_states = self.linear_1(caption) @@ -877,21 +1087,53 @@ def construct(self, caption): return hidden_states +class IPAdapterPlusImageProjectionBlock(nn.Cell): + def __init__( + self, + embed_dims: int = 768, + dim_head: int = 64, + heads: int = 16, + ffn_ratio: float = 4, + ) -> None: + super().__init__() + from .attention import FeedForward + from .normalization import LayerNorm + + self.ln0 = LayerNorm(embed_dims) + self.ln1 = LayerNorm(embed_dims) + self.attn = Attention( + query_dim=embed_dims, + dim_head=dim_head, + heads=heads, + out_bias=False, + ) + self.ff = nn.SequentialCell( + LayerNorm(embed_dims), + FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False), + ) + + def construct(self, x, latents, residual): + encoder_hidden_states = self.ln0(x) + latents = self.ln1(latents) + encoder_hidden_states = ops.cat([encoder_hidden_states, latents], axis=-2) + latents = self.attn(latents, encoder_hidden_states) + residual + latents = self.ff(latents) + latents + return latents + + class IPAdapterPlusImageProjection(nn.Cell): """Resampler of IP-Adapter Plus. Args: - ---- - embed_dims (int): The feature dimension. Defaults to 768. - output_dims (int): The number of output channels, that is the same - number of the channels in the - `unet.config.cross_attention_dim`. Defaults to 1024. - hidden_dims (int): The number of hidden channels. Defaults to 1280. - depth (int): The number of blocks. Defaults to 8. - dim_head (int): The number of head channels. Defaults to 64. - heads (int): Parallel attention heads. Defaults to 16. - num_queries (int): The number of queries. Defaults to 8. - ffn_ratio (float): The expansion ratio of feedforward network hidden + embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels, + that is the same + number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024. + hidden_dims (int): + The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults + to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads. + Defaults to 16. num_queries (int): + The number of queries. Defaults to 8. ffn_ratio (float): The expansion ratio + of feedforward network hidden layer channels. Defaults to 4. """ @@ -907,8 +1149,7 @@ def __init__( ffn_ratio: float = 4, ) -> None: super().__init__() - from .attention import FeedForward - from .normalization import LayerNorm # Lazy import to avoid circular import + from .normalization import LayerNorm self.latents = ms.Parameter(ops.randn(1, num_queries, hidden_dims) / hidden_dims**0.5, name="latents") @@ -917,56 +1158,111 @@ def __init__( self.proj_out = nn.Dense(hidden_dims, output_dims) self.norm_out = LayerNorm(output_dims) - layers = [] - for _ in range(depth): - layers.append( - nn.CellList( - [ - LayerNorm(hidden_dims), - LayerNorm(hidden_dims), - Attention( - query_dim=hidden_dims, - dim_head=dim_head, - heads=heads, - out_bias=False, - ), - nn.SequentialCell( - LayerNorm(hidden_dims), - FeedForward(hidden_dims, hidden_dims, activation_fn="gelu", mult=ffn_ratio, bias=False), - ), - ] - ) - ) - self.layers = nn.CellList(layers) + self.layers = nn.CellList( + [IPAdapterPlusImageProjectionBlock(hidden_dims, dim_head, heads, ffn_ratio) for _ in range(depth)] + ) def construct(self, x: ms.Tensor) -> ms.Tensor: """Forward pass. Args: - ---- x (ms.Tensor): Input Tensor. - Returns: - ------- ms.Tensor: Output Tensor. """ - latents = self.latents.tile((x.size(0), 1, 1)) + latents = self.latents.tile((x.shape[0], 1, 1)) x = self.proj_in(x) - for ln0, ln1, attn, ff in self.layers: + for block in self.layers: residual = latents - - encoder_hidden_states = ln0(x) - latents = ln1(latents) - encoder_hidden_states = ops.cat([encoder_hidden_states, latents], axis=-2) - latents = attn(latents, encoder_hidden_states) + residual - latents = ff(latents) + latents + latents = block(x, latents, residual) latents = self.proj_out(latents) return self.norm_out(latents) +class IPAdapterFaceIDPlusImageProjection(nn.Cell): + """FacePerceiverResampler of IP-Adapter Plus. + + Args: + embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels, + that is the same + number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024. + hidden_dims (int): + The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults + to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads. + Defaults to 16. num_tokens (int): Number of tokens num_queries (int): The number of queries. Defaults to 8. + ffn_ratio (float): The expansion ratio of feedforward network hidden + layer channels. Defaults to 4. + ffproj_ratio (float): The expansion ratio of feedforward network hidden + layer channels (for ID embeddings). Defaults to 4. + """ + + def __init__( + self, + embed_dims: int = 768, + output_dims: int = 768, + hidden_dims: int = 1280, + id_embeddings_dim: int = 512, + depth: int = 4, + dim_head: int = 64, + heads: int = 16, + num_tokens: int = 4, + num_queries: int = 8, + ffn_ratio: float = 4, + ffproj_ratio: int = 2, + ) -> None: + super().__init__() + from .attention import FeedForward + from .normalization import LayerNorm + + self.num_tokens = num_tokens + self.embed_dim = embed_dims + self.clip_embeds = None + self.shortcut = False + self.shortcut_scale = 1.0 + + self.proj = FeedForward(id_embeddings_dim, embed_dims * num_tokens, activation_fn="gelu", mult=ffproj_ratio) + self.norm = LayerNorm(embed_dims) + + self.proj_in = nn.Dense(hidden_dims, embed_dims) + + self.proj_out = nn.Dense(embed_dims, output_dims) + self.norm_out = LayerNorm(output_dims) + + self.layers = nn.CellList( + [IPAdapterPlusImageProjectionBlock(embed_dims, dim_head, heads, ffn_ratio) for _ in range(depth)] + ) + + def construct(self, id_embeds: ms.Tensor) -> ms.Tensor: + """Forward pass. + + Args: + id_embeds (ms.Tensor): Input Tensor (ID embeds). + Returns: + ms.Tensor: Output Tensor. + """ + id_embeds = id_embeds.to(self.clip_embeds.dtype) + id_embeds = self.proj(id_embeds) + id_embeds = id_embeds.reshape(-1, self.num_tokens, self.embed_dim) + id_embeds = self.norm(id_embeds) + latents = id_embeds + + clip_embeds = self.proj_in(self.clip_embeds) + x = clip_embeds.reshape(-1, clip_embeds.shape[2], clip_embeds.shape[3]) + + for block in self.layers: + residual = latents + latents = block(x, latents, residual) + + latents = self.proj_out(latents) + out = self.norm_out(latents) + if self.shortcut: + out = id_embeds + self.shortcut_scale * out + return out + + class MultiIPAdapterImageProjection(nn.Cell): def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Cell], Tuple[nn.Cell]]): super().__init__() diff --git a/mindone/diffusers/models/model_loading_utils.py b/mindone/diffusers/models/model_loading_utils.py new file mode 100644 index 0000000000..2963633eda --- /dev/null +++ b/mindone/diffusers/models/model_loading_utils.py @@ -0,0 +1,599 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import importlib +import json +import os +from collections import OrderedDict +from contextlib import contextmanager, nullcontext +from dataclasses import dataclass, field +from functools import lru_cache +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, TypeVar, Union + +from huggingface_hub.utils import EntryNotFoundError + +import mindspore as ms +from mindspore import nn + +from ...safetensors.mindspore import load_file as safe_load_file +from ..utils import ( + SAFE_WEIGHTS_INDEX_NAME, + SAFETENSORS_FILE_EXTENSION, + WEIGHTS_INDEX_NAME, + _add_variant, + _get_model_file, + logging, +) + +logger = logging.get_logger(__name__) + +_CLASS_REMAPPING_DICT = { + "Transformer2DModel": { + "ada_norm_zero": "DiTTransformer2DModel", + "ada_norm_single": "PixArtTransformer2DModel", + } +} + + +def _fetch_remapped_cls_from_config(config, old_class): + previous_class_name = old_class.__name__ + remapped_class_name = _CLASS_REMAPPING_DICT.get(previous_class_name).get(config["norm_type"], None) + + # Details: + # https://github.com/huggingface/diffusers/pull/7647#discussion_r1621344818 + if remapped_class_name: + # load diffusers library to import compatible and original scheduler + diffusers_library = importlib.import_module(__name__.split(".")[0] + ".diffusers") + remapped_class = getattr(diffusers_library, remapped_class_name) + logger.info( + f"Changing class object to be of `{remapped_class_name}` type from `{previous_class_name}` type." + f"This is because `{previous_class_name}` is scheduled to be deprecated in a future version. Note that this" + " DOESN'T affect the final results." + ) + return remapped_class + else: + return old_class + + +def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None): + """ + Reads a checkpoint file, returning properly formatted errors if they arise. + """ + try: + file_extension = os.path.basename(checkpoint_file).split(".")[-1] + if file_extension == SAFETENSORS_FILE_EXTENSION: + return safe_load_file(checkpoint_file) + else: + raise NotImplementedError( + f"Only supports deserialization of weights file in safetensors format, but got {checkpoint_file}" + ) + except Exception as e: + try: + with open(checkpoint_file) as f: + if f.read().startswith("version"): + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. Please install " + "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder " + "you cloned." + ) + else: + raise ValueError( + f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained " + "model. Make sure you have saved the model properly." + ) from e + except (UnicodeDecodeError, ValueError): + raise OSError( + f"Unable to load weights from checkpoint file for '{checkpoint_file}' " f"at '{checkpoint_file}'. " + ) + + +def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]: + # TODO: error_msgs is always empty for now. Maybe we need to rewrite MindSpore's `load_param_into_net`. + # Error msgs should contain caught exception like size mismatch instead of missing/unexpected keys. + # TODO: We should support loading float16 state_dict into float32 model, like PyTorch's behavior. + error_msgs = [] + # TODO: State dict loading in mindspore does not cast dtype correctly. We do it manually. It's might unsafe. + local_state = {k: v for k, v in model_to_load.parameters_and_names()} + for k, v in state_dict.items(): + if k in local_state: + v.set_dtype(local_state[k].dtype) + else: + pass # unexpect key keeps origin dtype + ms.load_param_into_net(model_to_load, state_dict, strict_load=True) + return error_msgs + + +def _fetch_index_file( + is_local, + pretrained_model_name_or_path, + subfolder, + use_safetensors, + cache_dir, + variant, + force_download, + resume_download, + proxies, + local_files_only, + token, + revision, + user_agent, + commit_hash, +): + if is_local: + index_file = Path( + pretrained_model_name_or_path, + subfolder or "", + _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant), + ) + else: + index_file_in_repo = Path( + subfolder or "", + _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant), + ).as_posix() + try: + index_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=index_file_in_repo, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + index_file = Path(index_file) + except (EntryNotFoundError, EnvironmentError): + index_file = None + + return index_file + + +# =============================================== +# Sharded loading utils by huggingface Accelerate +# =============================================== + +_SAFE_MODEL_NAME = "model" +_SAFE_WEIGHTS_NAME = f"{_SAFE_MODEL_NAME}.safetensors" + + +# Copied from mindone.transformers.modeling_utils.silence_mindspore_logger +@contextmanager +def silence_mindspore_logger(): + ms_logger = ms.log._get_logger() + ms_level = ms_logger.level + ms_logger.setLevel("ERROR") + yield + ms_logger.setLevel(ms_level) + + +def load_checkpoint_and_dispatch( + model: nn.Cell, + checkpoint: Union[str, os.PathLike], + dtype: Optional[Union[str, ms.Type]] = None, + strict: bool = False, +): + """ + Loads a (potentially sharded) checkpoint inside a model, potentially sending weights to a given device as they are + loaded and adds the various hooks that will make this model run properly (even if split across devices). + + Args: + model (`mindspore.nn.Cell`): The model in which we want to load a checkpoint. + checkpoint (`str` or `os.PathLike`): + The folder checkpoint to load. It can be: + - a path to a file containing a whole model state dict + - a path to a `.json` file containing the index to a sharded checkpoint + - a path to a folder containing a unique `.index.json` file and the shards of a checkpoint. + dtype (`str` or `mindspore.dtype`, *optional*): + If provided, the weights will be converted to that type when loaded. + force_hooks (`bool`, *optional*, defaults to `False`): + Whether or not to force device hooks to be attached to the model even if all layers are dispatched to a + single device. + strict (`bool`, *optional*, defaults to `False`): + Whether to strictly enforce that the keys in the checkpoint state_dict match the keys of the model's + state_dict. + + Example: + + ```python + >>> from accelerate import init_empty_weights, load_checkpoint_and_dispatch + >>> from huggingface_hub import hf_hub_download + >>> from transformers import AutoConfig, AutoModelForCausalLM + + >>> # Download the Weights + >>> checkpoint = "EleutherAI/gpt-j-6B" + >>> weights_location = hf_hub_download(checkpoint, "pytorch_model.bin") + + >>> # Create a model and initialize it with empty weights + >>> config = AutoConfig.from_pretrained(checkpoint) + >>> with init_empty_weights(): + ... model = AutoModelForCausalLM.from_config(config) + + >>> # Load the checkpoint and dispatch it to the right devices + >>> model = load_checkpoint_and_dispatch( + ... model, weights_location, device_map="auto", no_split_module_classes=["GPTJBlock"] + ... ) + ``` + """ + + if isinstance(dtype, str): + # We accept "torch.float16" or just "float16" + dtype = dtype.replace("mindspore.", "") + dtype = getattr(ms, dtype) + + checkpoint_files = None + index_filename = None + if os.path.isfile(checkpoint): + if str(checkpoint).endswith(".json"): + index_filename = checkpoint + else: + checkpoint_files = [checkpoint] + elif os.path.isdir(checkpoint): + # check if the whole state dict is present + potential_state_safetensor = [f for f in os.listdir(checkpoint) if f == _SAFE_WEIGHTS_NAME] + if len(potential_state_safetensor) == 1: + checkpoint_files = [os.path.join(checkpoint, potential_state_safetensor[0])] + else: + # otherwise check for sharded checkpoints + potential_index = [f for f in os.listdir(checkpoint) if f.endswith(".index.json")] + if len(potential_index) == 0: + raise ValueError( + f"{checkpoint} is not a folder containing a `.index.json` file or a {_SAFE_WEIGHTS_NAME} file" + ) + elif len(potential_index) == 1: + index_filename = os.path.join(checkpoint, potential_index[0]) + else: + raise ValueError( + f"{checkpoint} containing more than one `.index.json` file, delete the irrelevant ones." + ) + else: + raise ValueError( + "`checkpoint` should be the path to a file containing a whole state dict, or the index of a sharded " + f"checkpoint, or a folder containing a sharded checkpoint or the whole state dict, but got {checkpoint}." + ) + + if index_filename is not None: + checkpoint_folder = os.path.split(index_filename)[0] + with open(index_filename) as f: + index = json.loads(f.read()) + + if "weight_map" in index: + index = index["weight_map"] + checkpoint_files = sorted(list(set(index.values()))) + checkpoint_files = [os.path.join(checkpoint_folder, f) for f in checkpoint_files] + + # Logic for missing/unexepected keys goes here. + unexpected_keys = set() + model_keys = set(model.parameters_dict().keys()) + is_sharded = index_filename is not None + cm = silence_mindspore_logger() if is_sharded else nullcontext() + with cm: + for checkpoint_file in checkpoint_files: + loaded_checkpoint = load_state_dict(checkpoint_file) + _ = _load_state_dict_into_model(model, loaded_checkpoint) + unexpected_keys.update(set(loaded_checkpoint.keys()) - model_keys) + del loaded_checkpoint + gc.collect() + + if not strict and len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint at {checkpoint} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}. This may or may not be an issue - make sure that the checkpoint does not have unnecessary parameters, or that the model definition correctly corresponds to the checkpoint." # noqa E501 + ) + + return model + + +# ============================================= +# Sharded saving by huggingface huggingface_hub +# ============================================= + +TensorT = TypeVar("TensorT") +TensorSizeFn_T = Callable[[TensorT], int] +StorageIDFn_T = Callable[[TensorT], Optional[Any]] + +_MAX_SHARD_SIZE = "5GB" +_SAFETENSORS_WEIGHTS_FILE_PATTERN = "model{suffix}.safetensors" +_SIZE_UNITS = { + "TB": 10**12, + "GB": 10**9, + "MB": 10**6, + "KB": 10**3, +} + + +@dataclass +class StateDictSplit: + is_sharded: bool = field(init=False) + metadata: Dict[str, Any] + filename_to_tensors: Dict[str, List[str]] + tensor_to_filename: Dict[str, str] + + def __post_init__(self): + self.is_sharded = len(self.filename_to_tensors) > 1 + + +def split_state_dict_into_shards_factory( + state_dict: Dict[str, TensorT], + *, + get_storage_size: TensorSizeFn_T, + filename_pattern: str, + get_storage_id: StorageIDFn_T = lambda tensor: None, + max_shard_size: Union[int, str] = _MAX_SHARD_SIZE, +) -> StateDictSplit: + """ + Split a model state dictionary in shards so that each shard is smaller than a given size. + + The shards are determined by iterating through the `state_dict` in the order of its keys. There is no optimization + made to make each shard as close as possible to the maximum size passed. For example, if the limit is 10GB and we + have tensors of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], [6+2+2GB] and not + [6+2+2GB], [6+2GB], [6GB]. + + + + If one of the model's tensor is bigger than `max_shard_size`, it will end up in its own shard which will have a + size greater than `max_shard_size`. + + + + Args: + state_dict (`Dict[str, Tensor]`): + The state dictionary to save. + get_storage_size (`Callable[[Tensor], int]`): + A function that returns the size of a tensor when saved on disk in bytes. + get_storage_id (`Callable[[Tensor], Optional[Any]]`, *optional*): + A function that returns a unique identifier to a tensor storage. Multiple different tensors can share the + same underlying storage. This identifier is guaranteed to be unique and constant for this tensor's storage + during its lifetime. Two tensor storages with non-overlapping lifetimes may have the same id. + filename_pattern (`str`, *optional*): + The pattern to generate the files names in which the model will be saved. Pattern must be a string that + can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix` + max_shard_size (`int` or `str`, *optional*): + The maximum size of each shard, in bytes. Defaults to 5GB. + + Returns: + [`StateDictSplit`]: A `StateDictSplit` object containing the shards and the index to retrieve them. + """ + storage_id_to_tensors: Dict[Any, List[str]] = {} + + shard_list: List[Dict[str, TensorT]] = [] + current_shard: Dict[str, TensorT] = {} + current_shard_size = 0 + total_size = 0 + + if isinstance(max_shard_size, str): + max_shard_size = parse_size_to_int(max_shard_size) + + for key, tensor in state_dict.items(): + # when bnb serialization is used the weights in the state dict can be strings + # check: https://github.com/huggingface/transformers/pull/24416 for more details + if isinstance(tensor, str): + logger.info("Skipping tensor %s as it is a string (bnb serialization)", key) + continue + + # If a `tensor` shares the same underlying storage as another tensor, we put `tensor` in the same `block` + storage_id = get_storage_id(tensor) + if storage_id is not None: + if storage_id in storage_id_to_tensors: + # We skip this tensor for now and will reassign to correct shard later + storage_id_to_tensors[storage_id].append(key) + continue + else: + # This is the first tensor with this storage_id, we create a new entry + # in the storage_id_to_tensors dict => we will assign the shard id later + storage_id_to_tensors[storage_id] = [key] + + # Compute tensor size + tensor_size = get_storage_size(tensor) + + # If this tensor is bigger than the maximal size, we put it in its own shard + if tensor_size > max_shard_size: + total_size += tensor_size + shard_list.append({key: tensor}) + continue + + # If this tensor is going to tip up over the maximal size, we split. + # Current shard already has some tensors, we add it to the list of shards and create a new one. + if current_shard_size + tensor_size > max_shard_size: + shard_list.append(current_shard) + current_shard = {} + current_shard_size = 0 + + # Add the tensor to the current shard + current_shard[key] = tensor + current_shard_size += tensor_size + total_size += tensor_size + + # Add the last shard + if len(current_shard) > 0: + shard_list.append(current_shard) + nb_shards = len(shard_list) + + # Loop over the tensors that share the same storage and assign them together + for storage_id, keys in storage_id_to_tensors.items(): + # Let's try to find the shard where the first tensor of this storage is and put all tensors in the same shard + for shard in shard_list: + if keys[0] in shard: + for key in keys: + shard[key] = state_dict[key] + break + + # If we only have one shard, we return it => no need to build the index + if nb_shards == 1: + filename = filename_pattern.format(suffix="") + return StateDictSplit( + metadata={"total_size": total_size}, + filename_to_tensors={filename: list(state_dict.keys())}, + tensor_to_filename={key: filename for key in state_dict.keys()}, + ) + + # Now that each tensor is assigned to a shard, let's assign a filename to each shard + tensor_name_to_filename = {} + filename_to_tensors = {} + for idx, shard in enumerate(shard_list): + filename = filename_pattern.format(suffix=f"-{idx+1:05d}-of-{nb_shards:05d}") + for key in shard: + tensor_name_to_filename[key] = filename + filename_to_tensors[filename] = list(shard.keys()) + + # Build the index and return + return StateDictSplit( + metadata={"total_size": total_size}, + filename_to_tensors=filename_to_tensors, + tensor_to_filename=tensor_name_to_filename, + ) + + +def parse_size_to_int(size_as_str: str) -> int: + """ + Parse a size expressed as a string with digits and unit (like `"5MB"`) to an integer (in bytes). + + Supported units are "TB", "GB", "MB", "KB". + + Args: + size_as_str (`str`): The size to convert. Will be directly returned if an `int`. + + Example: + + ```py + >>> parse_size_to_int("5MB") + 5000000 + ``` + """ + size_as_str = size_as_str.strip() + + # Parse unit + unit = size_as_str[-2:].upper() + if unit not in _SIZE_UNITS: + raise ValueError(f"Unit '{unit}' not supported. Supported units are TB, GB, MB, KB. Got '{size_as_str}'.") + multiplier = _SIZE_UNITS[unit] + + # Parse value + try: + value = float(size_as_str[:-2].strip()) + except ValueError as e: + raise ValueError(f"Could not parse the size value from '{size_as_str}': {e}") from e + + return int(value * multiplier) + + +def split_torch_state_dict_into_shards( + state_dict: Dict[str, "ms.Tensor"], + *, + filename_pattern: str = _SAFETENSORS_WEIGHTS_FILE_PATTERN, + max_shard_size: Union[int, str] = _MAX_SHARD_SIZE, +) -> StateDictSplit: + """ + Split a model state dictionary in shards so that each shard is smaller than a given size. + + The shards are determined by iterating through the `state_dict` in the order of its keys. There is no optimization + made to make each shard as close as possible to the maximum size passed. For example, if the limit is 10GB and we + have tensors of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], [6+2+2GB] and not + [6+2+2GB], [6+2GB], [6GB]. + + + + + To save a model state dictionary to the disk, see [`save_torch_state_dict`]. This helper uses + `split_torch_state_dict_into_shards` under the hood. + + + + + + If one of the model's tensor is bigger than `max_shard_size`, it will end up in its own shard which will have a + size greater than `max_shard_size`. + + + + Args: + state_dict (`Dict[str, ms.Tensor]`): + The state dictionary to save. + filename_pattern (`str`, *optional*): + The pattern to generate the files names in which the model will be saved. Pattern must be a string that + can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix` + Defaults to `"model{suffix}.safetensors"`. + max_shard_size (`int` or `str`, *optional*): + The maximum size of each shard, in bytes. Defaults to 5GB. + + Returns: + [`StateDictSplit`]: A `StateDictSplit` object containing the shards and the index to retrieve them. + + Example: + ```py + >>> import json + >>> import os + >>> from safetensors.torch import save_file as safe_save_file + >>> from huggingface_hub import split_torch_state_dict_into_shards + + >>> def save_state_dict(state_dict: Dict[str, ms.Tensor], save_directory: str): + ... state_dict_split = split_torch_state_dict_into_shards(state_dict) + ... for filename, tensors in state_dict_split.filename_to_tensors.items(): + ... shard = {tensor: state_dict[tensor] for tensor in tensors} + ... safe_save_file( + ... shard, + ... os.path.join(save_directory, filename), + ... metadata={"format": "pt"}, + ... ) + ... if state_dict_split.is_sharded: + ... index = { + ... "metadata": state_dict_split.metadata, + ... "weight_map": state_dict_split.tensor_to_filename, + ... } + ... with open(os.path.join(save_directory, "model.safetensors.index.json"), "w") as f: + ... f.write(json.dumps(index, indent=2)) + ``` + """ + return split_state_dict_into_shards_factory( + state_dict, + max_shard_size=max_shard_size, + filename_pattern=filename_pattern, + get_storage_size=get_torch_storage_size, + ) + + +def get_torch_storage_size(tensor: "ms.Tensor") -> int: + """ + Taken from https://github.com/huggingface/safetensors/blob/08db34094e9e59e2f9218f2df133b7b4aaff5a99/bindings/python/py_src/safetensors/torch.py#L31C1-L41C59 + """ + return tensor.nelement() * _get_dtype_size(tensor.dtype) + + +@lru_cache() +def _get_dtype_size(dtype: "ms.Dtype") -> int: + """ + Taken from https://github.com/huggingface/safetensors/blob/08db34094e9e59e2f9218f2df133b7b4aaff5a99/bindings/python/py_src/safetensors/torch.py#L344 + """ + import mindspore as ms + + _SIZE = { + ms.int64: 8, + ms.float32: 4, + ms.int32: 4, + ms.bfloat16: 2, + ms.float16: 2, + ms.int16: 2, + ms.uint8: 1, + ms.int8: 1, + ms.bool_: 1, + ms.float64: 8, + } + return _SIZE[dtype] diff --git a/mindone/diffusers/models/modeling_outputs.py b/mindone/diffusers/models/modeling_outputs.py index 99ab69ccce..142a9c0b7d 100644 --- a/mindone/diffusers/models/modeling_outputs.py +++ b/mindone/diffusers/models/modeling_outputs.py @@ -17,3 +17,18 @@ class AutoencoderKLOutput(BaseOutput): """ latent: ms.Tensor + + +@dataclass +class Transformer2DModelOutput(BaseOutput): + """ + The output of [`Transformer2DModel`]. + + Args: + sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` or + `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. + """ + + sample: "ms.Tensor" # noqa: F821 diff --git a/mindone/diffusers/models/modeling_utils.py b/mindone/diffusers/models/modeling_utils.py index 1a7e7c7bac..00a4c5ed7c 100644 --- a/mindone/diffusers/models/modeling_utils.py +++ b/mindone/diffusers/models/modeling_utils.py @@ -13,10 +13,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect +import json import os +import re from collections import OrderedDict from functools import partial -from typing import Any, Callable, List, Optional, Union +from pathlib import Path +from typing import Any, Callable, Optional, Union from huggingface_hub import create_repo from huggingface_hub.utils import validate_hf_hub_args @@ -24,24 +28,34 @@ import mindspore as ms from mindspore import nn, ops -from mindone.safetensors.mindspore import load_file as safe_load_file from mindone.safetensors.mindspore import save_file as safe_save_file from .. import __version__ from ..utils import ( CONFIG_NAME, - SAFETENSORS_FILE_EXTENSION, + SAFE_WEIGHTS_INDEX_NAME, SAFETENSORS_WEIGHTS_NAME, + WEIGHTS_INDEX_NAME, WEIGHTS_NAME, _add_variant, + _get_checkpoint_shard_files, _get_model_file, deprecate, logging, ) from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populate_model_card +from .model_loading_utils import ( + _fetch_index_file, + _load_state_dict_into_model, + load_checkpoint_and_dispatch, + load_state_dict, + split_torch_state_dict_into_shards, +) logger = logging.get_logger(__name__) +_REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}") + def _get_pt2ms_mappings(m): mappings = {} # pt_param_name: (ms_param_name, pt_param_to_ms_param_func) @@ -82,54 +96,6 @@ def get_parameter_dtype(module: nn.Cell) -> ms.Type: return params[0].dtype -def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None): - """ - Reads a checkpoint file, returning properly formatted errors if they arise. - """ - try: - file_extension = os.path.basename(checkpoint_file).split(".")[-1] - if file_extension == SAFETENSORS_FILE_EXTENSION: - return safe_load_file(checkpoint_file) - else: - raise NotImplementedError( - f"Only supports deserialization of weights file in safetensors format, but got {checkpoint_file}" - ) - except Exception as e: - try: - with open(checkpoint_file) as f: - if f.read().startswith("version"): - raise OSError( - "You seem to have cloned a repository without having git-lfs installed. Please install " - "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder " - "you cloned." - ) - else: - raise ValueError( - f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained " - "model. Make sure you have saved the model properly." - ) from e - except (UnicodeDecodeError, ValueError): - raise OSError( - f"Unable to load weights from checkpoint file for '{checkpoint_file}' " f"at '{checkpoint_file}'. " - ) - - -def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]: - # TODO: error_msgs is always empty for now. Maybe we need to rewrite MindSpore's `load_param_into_net`. - # Error msgs should contain caught exception like size mismatch instead of missing/unexpected keys. - # TODO: We should support loading float16 state_dict into float32 model, like PyTorch's behavior. - error_msgs = [] - # TODO: State dict loading in mindspore does not cast dtype correctly. We do it manually. It's might unsafe. - local_state = {k: v for k, v in model_to_load.parameters_and_names()} - for k, v in state_dict.items(): - if k in local_state: - v.set_dtype(local_state[k].dtype) - else: - pass # unexpect key keeps origin dtype - ms.load_param_into_net(model_to_load, state_dict, strict_load=True) - return error_msgs - - class ModelMixin(nn.Cell, PushToHubMixin): r""" Base class for all models. @@ -144,6 +110,7 @@ class ModelMixin(nn.Cell, PushToHubMixin): _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] _supports_gradient_checkpointing = False _keys_to_ignore_on_load_unexpected = None + _no_split_modules = None def __init__(self): super().__init__() @@ -250,6 +217,7 @@ def save_pretrained( save_function: Optional[Callable] = None, safe_serialization: bool = True, variant: Optional[str] = None, + max_shard_size: Union[int, str] = "10GB", push_to_hub: bool = False, **kwargs, ): @@ -272,6 +240,13 @@ def save_pretrained( Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. variant (`str`, *optional*): If specified, weights are saved in the format `pytorch_model..bin`. + max_shard_size (`int` or `str`, defaults to `"10GB"`): + The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size + lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`). + If expressed as an integer, the unit is bytes. Note that this limit will be decreased after a certain + period of time (starting from Oct 2024) to allow users to upgrade to the latest version of `diffusers`. + This is to establish a common default size for this argument across different libraries in the Hugging + Face ecosystem (`transformers`, and `accelerate`, for example). push_to_hub (`bool`, *optional*, defaults to `False`): Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the repository you want to push to with `repo_id` (will default to the name of `save_directory` in your @@ -283,6 +258,14 @@ def save_pretrained( logger.error(f"Provided path ({save_directory}) should be a directory, not a file") return + weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME + weights_name = _add_variant(weights_name, variant) + weight_name_split = weights_name.split(".") + if len(weight_name_split) in [2, 3]: + weights_name_pattern = weight_name_split[0] + "{suffix}." + ".".join(weight_name_split[1:]) + else: + raise ValueError(f"Invalid {weights_name} provided.") + os.makedirs(save_directory, exist_ok=True) if push_to_hub: @@ -304,22 +287,64 @@ def save_pretrained( # Save the model state_dict = {k: v for k, v in model_to_save.parameters_and_names()} - weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME - weights_name = _add_variant(weights_name, variant) - # Save the model - if safe_serialization: - safe_save_file(state_dict, os.path.join(save_directory, weights_name), metadata={"format": "np"}) - else: - ms.save_checkpoint(state_dict, os.path.join(save_directory, weights_name)) + state_dict_split = split_torch_state_dict_into_shards( + state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern + ) - logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}") + # Clean the folder from a previous save + if is_main_process: + for filename in os.listdir(save_directory): + if filename in state_dict_split.filename_to_tensors.keys(): + continue + full_filename = os.path.join(save_directory, filename) + if not os.path.isfile(full_filename): + continue + weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "") + weights_without_ext = weights_without_ext.replace("{suffix}", "") + filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "") + # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005 + if ( + filename.startswith(weights_without_ext) + and _REGEX_SHARD.fullmatch(filename_without_ext) is not None + ): + os.remove(full_filename) + + for filename, tensors in state_dict_split.filename_to_tensors.items(): + shard = {tensor: state_dict[tensor] for tensor in tensors} + filepath = os.path.join(save_directory, filename) + if safe_serialization: + # At some point we will need to deal better with save_function (used for TPU and other distributed + # joyfulness), but for now this enough. + safe_save_file(shard, filepath, metadata={"format": "np"}) + else: + ms.save_checkpoint(shard, filepath) + + if state_dict_split.is_sharded: + index = { + "metadata": state_dict_split.metadata, + "weight_map": state_dict_split.tensor_to_filename, + } + save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME + save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant)) + # Save the index as well + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + logger.info( + f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be " + f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + else: + path_to_weights = os.path.join(save_directory, weights_name) + logger.info(f"Model weights saved in {path_to_weights}") if push_to_hub: # Create a new empty model card and eventually tag it model_card = load_or_create_model_card(repo_id, token=token) model_card = populate_model_card(model_card) - model_card.save(os.path.join(save_directory, "README.md")) + model_card.save(Path(save_directory, "README.md").as_posix()) self._upload_folder( save_directory, @@ -356,9 +381,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to resume downloading the model weights and configuration files. If set to `False`, any - incompletely downloaded files are deleted. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1 + of Diffusers. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. @@ -419,7 +444,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) force_download = kwargs.pop("force_download", False) from_flax = kwargs.pop("from_flax", False) - resume_download = kwargs.pop("resume_download", False) + resume_download = kwargs.pop("resume_download", None) proxies = kwargs.pop("proxies", None) output_loading_info = kwargs.pop("output_loading_info", False) local_files_only = kwargs.pop("local_files_only", None) @@ -461,12 +486,49 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P **kwargs, ) + # Determine if we're loading from a directory of sharded checkpoints. + is_sharded = False + index_file = None + is_local = os.path.isdir(pretrained_model_name_or_path) + index_file = _fetch_index_file( + is_local=is_local, + pretrained_model_name_or_path=pretrained_model_name_or_path, + subfolder=subfolder or "", + use_safetensors=use_safetensors, + cache_dir=cache_dir, + variant=variant, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + user_agent=user_agent, + commit_hash=commit_hash, + ) + if index_file is not None and index_file.is_file(): + is_sharded = True + # load model model_file = None if from_flax: raise NotImplementedError("loading flax checkpoint in mindspore model is not yet supported.") else: - if use_safetensors: + if is_sharded: + sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files( + pretrained_model_name_or_path, + index_file, + cache_dir=cache_dir, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder or "", + ) + + elif use_safetensors and not is_sharded: try: model_file = _get_model_file( pretrained_model_name_or_path, @@ -483,10 +545,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P commit_hash=commit_hash, ) except IOError as e: + logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}") if not allow_pickle: - raise e - pass - if model_file is None: + raise + logger.warning( + "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead." + ) + + if model_file is None and not is_sharded: model_file = _get_model_file( pretrained_model_name_or_path, weights_name=_add_variant(WEIGHTS_NAME, variant), @@ -504,23 +570,31 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P model = cls.from_config(config, **unused_kwargs) - state_dict = load_state_dict(model_file, variant=variant) - model._convert_deprecated_attention_blocks(state_dict) - - model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( - model, - state_dict, - model_file, - pretrained_model_name_or_path, - ignore_mismatched_sizes=ignore_mismatched_sizes, - ) + if is_sharded: + load_checkpoint_and_dispatch( + model, + sharded_ckpt_cached_folder, + dtype=mindspore_dtype, + strict=True, + ) + else: + state_dict = load_state_dict(model_file, variant=variant) + model._convert_deprecated_attention_blocks(state_dict) + + model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( + model, + state_dict, + model_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=ignore_mismatched_sizes, + ) - loading_info = { - "missing_keys": missing_keys, - "unexpected_keys": unexpected_keys, - "mismatched_keys": mismatched_keys, - "error_msgs": error_msgs, - } + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + "error_msgs": error_msgs, + } if mindspore_dtype is not None and not isinstance(mindspore_dtype, ms.Type): raise ValueError( @@ -533,7 +607,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # Set model in evaluation mode to deactivate DropOut modules by default model.set_train(False) - if output_loading_info: + if not is_sharded and output_loading_info: return model, loading_info return model @@ -643,6 +717,15 @@ def _find_mismatched_keys( return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs + @classmethod + def _get_signature_keys(cls, obj): + parameters = inspect.signature(obj.__init__).parameters + required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} + optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) + expected_modules = set(required_parameters.keys()) - {"self"} + + return expected_modules, optional_parameters + def to(self, dtype: Optional[ms.Type] = None): for p in self.get_parameters(): p.set_dtype(dtype) @@ -746,3 +829,58 @@ def recursive_find_attn_block(name, module): state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight") if f"{path}.proj_attn.bias" in state_dict: state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias") + + +class LegacyModelMixin(ModelMixin): + r""" + A subclass of `ModelMixin` to resolve class mapping from legacy classes (like `Transformer2DModel`) to more + pipeline-specific classes (like `DiTTransformer2DModel`). + """ + + @classmethod + @validate_hf_hub_args + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + # To prevent depedency import problem. + from .model_loading_utils import _fetch_remapped_cls_from_config + + # Create a copy of the kwargs so that we don't mess with the keyword arguments in the downstream calls. + kwargs_copy = kwargs.copy() + + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", None) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + + # Load config if we don't provide a configuration + config_path = pretrained_model_name_or_path + + user_agent = { + "diffusers": __version__, + "file_type": "model", + "framework": "pytorch", + } + + # load config + config, _, _ = cls.load_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + return_commit_hash=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + **kwargs, + ) + # resolve remapping + remapped_class = _fetch_remapped_cls_from_config(config, cls) + + return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy) diff --git a/mindone/diffusers/models/normalization.py b/mindone/diffusers/models/normalization.py index 024e8ce008..5587c433ee 100644 --- a/mindone/diffusers/models/normalization.py +++ b/mindone/diffusers/models/normalization.py @@ -184,7 +184,8 @@ def __init__( raise ValueError(f"unknown norm_type {norm_type}") def construct(self, x: ms.Tensor, conditioning_embedding: ms.Tensor) -> ms.Tensor: - emb = self.linear(self.silu(conditioning_embedding)) + # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT) + emb = self.linear(self.silu(conditioning_embedding).to(x.dtype)) scale, shift = ops.chunk(emb, 2, axis=1) x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] return x diff --git a/mindone/diffusers/models/resnet.py b/mindone/diffusers/models/resnet.py index af00e62e58..f7ff9abca4 100644 --- a/mindone/diffusers/models/resnet.py +++ b/mindone/diffusers/models/resnet.py @@ -85,8 +85,6 @@ def __init__( self.output_scale_factor = output_scale_factor self.time_embedding_norm = time_embedding_norm - conv_cls = nn.Conv2d - if groups_out is None: groups_out = groups @@ -97,7 +95,7 @@ def __init__( else: raise ValueError(f" unsupported time_embedding_norm: {self.time_embedding_norm}") - self.conv1 = conv_cls( + self.conv1 = nn.Conv2d( in_channels, out_channels, kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True ) @@ -111,7 +109,7 @@ def __init__( self.dropout = nn.Dropout(p=dropout) conv_2d_out_channels = conv_2d_out_channels or out_channels - self.conv2 = conv_cls( + self.conv2 = nn.Conv2d( out_channels, conv_2d_out_channels, kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True ) @@ -127,7 +125,7 @@ def __init__( self.conv_shortcut = None if self.use_in_shortcut: - self.conv_shortcut = conv_cls( + self.conv_shortcut = nn.Conv2d( in_channels, conv_2d_out_channels, kernel_size=1, @@ -202,8 +200,8 @@ class ResnetBlock2D(nn.Cell): eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use. time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config. - By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" - for a stronger conditioning with scale and shift. + By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" for a + stronger conditioning with scale and shift. kernel (`ms.Tensor`, optional, default to None): FIR filter, see [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`]. output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output. @@ -261,23 +259,20 @@ def __init__( self.time_embedding_norm = time_embedding_norm self.skip_time_act = skip_time_act - linear_cls = nn.Dense - conv_cls = nn.Conv2d - if groups_out is None: groups_out = groups self.norm1 = GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) - self.conv1 = conv_cls( + self.conv1 = nn.Conv2d( in_channels, out_channels, kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True ) if temb_channels is not None: if self.time_embedding_norm == "default": - self.time_emb_proj = linear_cls(temb_channels, out_channels) + self.time_emb_proj = nn.Dense(temb_channels, out_channels) elif self.time_embedding_norm == "scale_shift": - self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels) + self.time_emb_proj = nn.Dense(temb_channels, 2 * out_channels) else: raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") else: @@ -287,7 +282,7 @@ def __init__( self.dropout = nn.Dropout(p=dropout) conv_2d_out_channels = conv_2d_out_channels or out_channels - self.conv2 = conv_cls( + self.conv2 = nn.Conv2d( out_channels, conv_2d_out_channels, kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True ) @@ -315,7 +310,7 @@ def __init__( self.conv_shortcut = None if self.use_in_shortcut: - self.conv_shortcut = conv_cls( + self.conv_shortcut = nn.Conv2d( in_channels, conv_2d_out_channels, kernel_size=1, diff --git a/mindone/diffusers/models/transformers/__init__.py b/mindone/diffusers/models/transformers/__init__.py index ec011ad544..ce76773743 100644 --- a/mindone/diffusers/models/transformers/__init__.py +++ b/mindone/diffusers/models/transformers/__init__.py @@ -1,4 +1,7 @@ +from .dit_transformer_2d import DiTTransformer2DModel from .dual_transformer_2d import DualTransformer2DModel +from .hunyuan_transformer_2d import HunyuanDiT2DModel +from .pixart_transformer_2d import PixArtTransformer2DModel from .prior_transformer import PriorTransformer from .t5_film_transformer import T5FilmDecoder from .transformer_2d import Transformer2DModel diff --git a/mindone/diffusers/models/transformers/dit_transformer_2d.py b/mindone/diffusers/models/transformers/dit_transformer_2d.py new file mode 100644 index 0000000000..72d2511e34 --- /dev/null +++ b/mindone/diffusers/models/transformers/dit_transformer_2d.py @@ -0,0 +1,210 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Optional + +import mindspore as ms +from mindspore import nn, ops + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import logging +from ..attention import BasicTransformerBlock +from ..embeddings import PatchEmbed +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import LayerNorm + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class DiTTransformer2DModel(ModelMixin, ConfigMixin): + r""" + A 2D Transformer model as introduced in DiT (https://arxiv.org/abs/2212.09748). + + Parameters: + num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (int, optional, defaults to 72): The number of channels in each head. + in_channels (int, defaults to 4): The number of channels in the input. + out_channels (int, optional): + The number of channels in the output. Specify this parameter if the output channel number differs from the + input. + num_layers (int, optional, defaults to 28): The number of layers of Transformer blocks to use. + dropout (float, optional, defaults to 0.0): The dropout probability to use within the Transformer blocks. + norm_num_groups (int, optional, defaults to 32): + Number of groups for group normalization within Transformer blocks. + attention_bias (bool, optional, defaults to True): + Configure if the Transformer blocks' attention should contain a bias parameter. + sample_size (int, defaults to 32): + The width of the latent images. This parameter is fixed during training. + patch_size (int, defaults to 2): + Size of the patches the model processes, relevant for architectures working on non-sequential data. + activation_fn (str, optional, defaults to "gelu-approximate"): + Activation function to use in feed-forward networks within Transformer blocks. + num_embeds_ada_norm (int, optional, defaults to 1000): + Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during + inference. + upcast_attention (bool, optional, defaults to False): + If true, upcasts the attention mechanism dimensions for potentially improved performance. + norm_type (str, optional, defaults to "ada_norm_zero"): + Specifies the type of normalization used, can be 'ada_norm_zero'. + norm_elementwise_affine (bool, optional, defaults to False): + If true, enables element-wise affine parameters in the normalization layers. + norm_eps (float, optional, defaults to 1e-5): + A small constant added to the denominator in normalization layers to prevent division by zero. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 72, + in_channels: int = 4, + out_channels: Optional[int] = None, + num_layers: int = 28, + dropout: float = 0.0, + norm_num_groups: int = 32, + attention_bias: bool = True, + sample_size: int = 32, + patch_size: int = 2, + activation_fn: str = "gelu-approximate", + num_embeds_ada_norm: Optional[int] = 1000, + upcast_attention: bool = False, + norm_type: str = "ada_norm_zero", + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-5, + ): + super().__init__() + + # Validate inputs. + if norm_type != "ada_norm_zero": + raise NotImplementedError( + f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'." + ) + elif norm_type == "ada_norm_zero" and num_embeds_ada_norm is None: + raise ValueError( + f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None." + ) + + # Set some common variables used across the board. + self.attention_head_dim = attention_head_dim + self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + self.out_channels = in_channels if out_channels is None else out_channels + self.gradient_checkpointing = False + + # 2. Initialize the position embedding and transformer blocks. + self.height = self.config.sample_size + self.width = self.config.sample_size + + self.patch_size = self.config.patch_size + self.pos_embed = PatchEmbed( + height=self.config.sample_size, + width=self.config.sample_size, + patch_size=self.config.patch_size, + in_channels=self.config.in_channels, + embed_dim=self.inner_dim, + ) + + self.transformer_blocks = nn.CellList( + [ + BasicTransformerBlock( + self.inner_dim, + self.config.num_attention_heads, + self.config.attention_head_dim, + dropout=self.config.dropout, + activation_fn=self.config.activation_fn, + num_embeds_ada_norm=self.config.num_embeds_ada_norm, + attention_bias=self.config.attention_bias, + upcast_attention=self.config.upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=self.config.norm_elementwise_affine, + norm_eps=self.config.norm_eps, + ) + for _ in range(self.config.num_layers) + ] + ) + + # 3. Output blocks. + self.norm_out = LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Dense(self.inner_dim, 2 * self.inner_dim) + self.proj_out_2 = nn.Dense(self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def construct( + self, + hidden_states: ms.Tensor, + timestep: Optional[ms.Tensor] = None, + class_labels: Optional[ms.Tensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + return_dict: bool = False, + ): + """ + The [`DiTTransformer2DModel`] forward method. + + Args: + hidden_states (`ms.Tensor` of shape `(batch size, num latent pixels)` if discrete, `ms.Tensor` of shape `(batch size, channel, height, width)` if continuous): # noqa: E501 + Input `hidden_states`. + timestep ( `ms.Tensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `ms.Tensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # 1. Input + height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size + hidden_states = self.pos_embed(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + conditioning = self.transformer_blocks[0].norm1.emb(timestep, class_labels, hidden_dtype=hidden_states.dtype) + shift, scale = self.proj_out_1(ops.silu(conditioning)).chunk(2, axis=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + + # unpatchify + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + # hidden_states = ops.einsum("nhwpqc->nchpwq", hidden_states) + hidden_states = hidden_states.transpose(0, 5, 1, 3, 2, 4) + output = hidden_states.reshape(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/mindone/diffusers/models/transformers/dual_transformer_2d.py b/mindone/diffusers/models/transformers/dual_transformer_2d.py index baa00e3013..56b94453b7 100644 --- a/mindone/diffusers/models/transformers/dual_transformer_2d.py +++ b/mindone/diffusers/models/transformers/dual_transformer_2d.py @@ -15,7 +15,8 @@ from mindspore import nn -from .transformer_2d import Transformer2DModel, Transformer2DModelOutput +from ..modeling_outputs import Transformer2DModelOutput +from .transformer_2d import Transformer2DModel class DualTransformer2DModel(nn.Cell): @@ -119,13 +120,14 @@ def construct( A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + return_dict (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. Returns: - [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: - [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. + [`~models.transformers.transformer_2d.Transformer2DModelOutput`] or `tuple`: + [`~models.transformers.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is the sample tensor. """ input_states = hidden_states diff --git a/mindone/diffusers/models/transformers/hunyuan_transformer_2d.py b/mindone/diffusers/models/transformers/hunyuan_transformer_2d.py new file mode 100644 index 0000000000..b8b7e30826 --- /dev/null +++ b/mindone/diffusers/models/transformers/hunyuan_transformer_2d.py @@ -0,0 +1,531 @@ +# Copyright 2024 HunyuanDiT Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, Optional, Union + +import mindspore as ms +from mindspore import nn, ops + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import logging +from ..activations import SiLU +from ..attention import FeedForward +from ..attention_processor import Attention, AttentionProcessor, HunyuanAttnProcessor +from ..embeddings import HunyuanCombinedTimestepTextSizeStyleEmbedding, PatchEmbed, PixArtAlphaTextProjection +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormContinuous, LayerNorm + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class FP32LayerNorm(LayerNorm): + def construct(self, inputs: ms.Tensor) -> ms.Tensor: + origin_dtype = inputs.dtype + x, _, _ = self.layer_norm(inputs.float(), self.weight.float(), self.bias.float()) + return x.to(origin_dtype) + + +class AdaLayerNormShift(nn.Cell): + r""" + Norm layer modified to incorporate timestep embeddings. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`): The size of the embeddings dictionary. + """ + + def __init__(self, embedding_dim: int, elementwise_affine=True, eps=1e-6): + super().__init__() + self.silu = SiLU() + self.linear = nn.Dense(embedding_dim, embedding_dim) + self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps) + + def construct(self, x: ms.Tensor, emb: ms.Tensor) -> ms.Tensor: + shift = self.linear(self.silu(emb.to(ms.float32)).to(emb.dtype)) + x = self.norm(x) + shift.unsqueeze(dim=1) + return x + + +class HunyuanDiTBlock(nn.Cell): + r""" + Transformer block used in Hunyuan-DiT model (https://github.com/Tencent/HunyuanDiT). Allow skip connection and + QKNorm + + Parameters: + dim (`int`): + The number of channels in the input and output. + num_attention_heads (`int`): + The number of headsto use for multi-head attention. + cross_attention_dim (`int`,*optional*): + The size of the encoder_hidden_states vector for cross attention. + dropout(`float`, *optional*, defaults to 0.0): + The dropout probability to use. + activation_fn (`str`,*optional*, defaults to `"geglu"`): + Activation function to be used in feed-forward. . + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_eps (`float`, *optional*, defaults to 1e-6): + A small constant added to the denominator in normalization layers to prevent division by zero. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + ff_inner_dim (`int`, *optional*): + The size of the hidden layer in the feed-forward block. Defaults to `None`. + ff_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in the feed-forward block. + skip (`bool`, *optional*, defaults to `False`): + Whether to use skip connection. Defaults to `False` for down-blocks and mid-blocks. + qk_norm (`bool`, *optional*, defaults to `True`): + Whether to use normalization in QK calculation. Defaults to `True`. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + cross_attention_dim: int = 1024, + dropout=0.0, + activation_fn: str = "geglu", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-6, + final_dropout: bool = False, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + skip: bool = False, + qk_norm: bool = True, + ): + super().__init__() + + # Define 3 blocks. Each block has its own normalization layer. + # NOTE: when new version comes, check norm2 and norm 3 + # 1. Self-Attn + self.norm1 = AdaLayerNormShift(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=dim // num_attention_heads, + heads=num_attention_heads, + qk_norm="layer_norm" if qk_norm else None, + eps=1e-6, + bias=True, + processor=HunyuanAttnProcessor(), + ) + + # 2. Cross-Attn + self.norm2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + dim_head=dim // num_attention_heads, + heads=num_attention_heads, + qk_norm="layer_norm" if qk_norm else None, + eps=1e-6, + bias=True, + processor=HunyuanAttnProcessor(), + ) + # 3. Feed-forward + self.norm3 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.ff = FeedForward( + dim, + dropout=dropout, # 0.0 + activation_fn=activation_fn, # approx GeLU + final_dropout=final_dropout, # 0.0 + inner_dim=ff_inner_dim, # int(dim * mlp_ratio) + bias=ff_bias, + ) + + # 4. Skip Connection + if skip: + self.skip_norm = FP32LayerNorm(2 * dim, norm_eps, elementwise_affine=True) + self.skip_linear = nn.Dense(2 * dim, dim) + else: + self.skip_linear = None + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def construct( + self, + hidden_states: ms.Tensor, + encoder_hidden_states: Optional[ms.Tensor] = None, + temb: Optional[ms.Tensor] = None, + image_rotary_emb=None, + skip=None, + ) -> ms.Tensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Long Skip Connection + cat = None + if self.skip_linear is not None: + cat = ops.cat([hidden_states, skip], axis=-1) + cat = self.skip_norm(cat) + hidden_states = self.skip_linear(cat) + + # 1. Self-Attention + norm_hidden_states = self.norm1(hidden_states, temb) # checked: self.norm1 is correct + attn_output = self.attn1( + norm_hidden_states, + image_rotary_emb=image_rotary_emb, + ) + hidden_states = hidden_states + attn_output + + # 2. Cross-Attention + hidden_states = hidden_states + self.attn2( + self.norm2(hidden_states), + encoder_hidden_states=encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + ) + + # FFN Layer ### TODO: switch norm2 and norm3 in the state dict + mlp_inputs = self.norm3(hidden_states) + hidden_states = hidden_states + self.ff(mlp_inputs) + + return hidden_states + + +class HunyuanDiT2DModel(ModelMixin, ConfigMixin): + """ + HunYuanDiT: Diffusion model with a Transformer backbone. + + Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): + The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + patch_size (`int`, *optional*): + The size of the patch to use for the input. + activation_fn (`str`, *optional*, defaults to `"geglu"`): + Activation function to use in feed-forward. + sample_size (`int`, *optional*): + The width of the latent images. This is fixed during training since it is used to learn a number of + position embeddings. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + cross_attention_dim (`int`, *optional*): + The number of dimension in the clip text embedding. + hidden_size (`int`, *optional*): + The size of hidden layer in the conditioning embedding layers. + num_layers (`int`, *optional*, defaults to 1): + The number of layers of Transformer blocks to use. + mlp_ratio (`float`, *optional*, defaults to 4.0): + The ratio of the hidden layer size to the input size. + learn_sigma (`bool`, *optional*, defaults to `True`): + Whether to predict variance. + cross_attention_dim_t5 (`int`, *optional*): + The number dimensions in t5 text embedding. + pooled_projection_dim (`int`, *optional*): + The size of the pooled projection. + text_len (`int`, *optional*): + The length of the clip text embedding. + text_len_t5 (`int`, *optional*): + The length of the T5 text embedding. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "gelu-approximate", + sample_size=32, + hidden_size=1152, + num_layers: int = 28, + mlp_ratio: float = 4.0, + learn_sigma: bool = True, + cross_attention_dim: int = 1024, + norm_type: str = "layer_norm", + cross_attention_dim_t5: int = 2048, + pooled_projection_dim: int = 1024, + text_len: int = 77, + text_len_t5: int = 256, + ): + super().__init__() + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.num_heads = num_attention_heads + self.inner_dim = num_attention_heads * attention_head_dim + + self.text_embedder = PixArtAlphaTextProjection( + in_features=cross_attention_dim_t5, + hidden_size=cross_attention_dim_t5 * 4, + out_features=cross_attention_dim, + act_fn="silu_fp32", + ) + + self.text_embedding_padding = ms.Parameter( + ops.randn(text_len + text_len_t5, cross_attention_dim, dtype=ms.float32), + name="text_embedding_padding", + ) + + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + in_channels=in_channels, + embed_dim=hidden_size, + patch_size=patch_size, + pos_embed_type=None, + ) + + self.time_extra_emb = HunyuanCombinedTimestepTextSizeStyleEmbedding( + hidden_size, + pooled_projection_dim=pooled_projection_dim, + seq_len=text_len_t5, + cross_attention_dim=cross_attention_dim_t5, + ) + + # HunyuanDiT Blocks + self.blocks = nn.CellList( + [ + HunyuanDiTBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + activation_fn=activation_fn, + ff_inner_dim=int(self.inner_dim * mlp_ratio), + cross_attention_dim=cross_attention_dim, + qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details. + skip=layer > num_layers // 2, + ) + for layer in range(num_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Dense(self.inner_dim, patch_size * patch_size * self.out_channels, has_bias=True) + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: nn.Cell, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.name_cells().items(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.name_cells().items(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: nn.Cell, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.name_cells().items(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.name_cells().items(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(HunyuanAttnProcessor()) + + def construct( + self, + hidden_states, + timestep, + encoder_hidden_states=None, + text_embedding_mask=None, + encoder_hidden_states_t5=None, + text_embedding_mask_t5=None, + image_meta_size=None, + style=None, + image_rotary_emb=None, + return_dict=False, + ): + """ + The [`HunyuanDiT2DModel`] forward method. + + Args: + hidden_states (`ms.Tensor` of shape `(batch size, dim, height, width)`): + The input tensor. + timestep ( `ms.Tensor`, *optional*): + Used to indicate denoising step. + encoder_hidden_states ( `ms.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. This is the output of `BertModel`. + text_embedding_mask: ms.Tensor + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output + of `BertModel`. + encoder_hidden_states_t5 ( `ms.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. This is the output of T5 Text Encoder. + text_embedding_mask_t5: ms.Tensor + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output + of T5 Text Encoder. + image_meta_size (ms.Tensor): + Conditional embedding indicate the image sizes + style: ms.Tensor: + Conditional embedding indicate the style + image_rotary_emb (`ms.Tensor`): + The image rotary embeddings to apply on query and key tensors during attention calculation. + return_dict: bool + Whether to return a dictionary. + """ + + height, width = hidden_states.shape[-2:] + + hidden_states = self.pos_embed(hidden_states) + + temb = self.time_extra_emb( + timestep, encoder_hidden_states_t5, image_meta_size, style, hidden_dtype=timestep.dtype + ) # [B, D] + + # text projection + batch_size, sequence_length, _ = encoder_hidden_states_t5.shape + encoder_hidden_states_t5 = self.text_embedder( + encoder_hidden_states_t5.view(-1, encoder_hidden_states_t5.shape[-1]) + ) + encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, sequence_length, -1) + + encoder_hidden_states = ops.cat([encoder_hidden_states, encoder_hidden_states_t5], axis=1) + text_embedding_mask = ops.cat([text_embedding_mask, text_embedding_mask_t5], axis=-1) + text_embedding_mask = text_embedding_mask.unsqueeze(2).bool() + + encoder_hidden_states = ops.where(text_embedding_mask, encoder_hidden_states, self.text_embedding_padding) + + skips = [] + for layer, block in enumerate(self.blocks): + if layer > self.config["num_layers"] // 2: + skip = skips.pop() + hidden_states = block( + hidden_states, + temb=temb, + encoder_hidden_states=encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + skip=skip, + ) # (N, L, D) + else: + hidden_states = block( + hidden_states, + temb=temb, + encoder_hidden_states=encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + ) # (N, L, D) + + if layer < (self.config["num_layers"] // 2 - 1): + skips.append(hidden_states) + + # final layer + hidden_states = self.norm_out(hidden_states, temb.to(ms.float32)) + hidden_states = self.proj_out(hidden_states) + # (N, L, patch_size ** 2 * out_channels) + + # unpatchify: (N, out_channels, H, W) + patch_size = self.pos_embed.patch_size + height = height // patch_size + width = width // patch_size + + hidden_states = hidden_states.reshape( + hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels + ) + # hidden_states = ops.einsum("nhwpqc->nchpwq", hidden_states) + hidden_states = hidden_states.transpose(0, 5, 1, 3, 2, 4) + output = hidden_states.reshape( + hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size + ) + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking + def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + """ + Sets the attention processor to use [feed forward + chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). + + Parameters: + chunk_size (`int`, *optional*): + The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually + over each tensor of dim=`dim`. + dim (`int`, *optional*, defaults to `0`): + The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) + or dim=1 (sequence length). + """ + if dim not in [0, 1]: + raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") + + # By default chunk size is 1 + chunk_size = chunk_size or 1 + + def fn_recursive_feed_forward(module: nn.Cell, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.name_cells().values(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.name_cells().values(): + fn_recursive_feed_forward(module, chunk_size, dim) + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking + def disable_forward_chunking(self): + def fn_recursive_feed_forward(module: nn.Cell, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.name_cells().values(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.name_cells().values(): + fn_recursive_feed_forward(module, None, 0) diff --git a/mindone/diffusers/models/transformers/pixart_transformer_2d.py b/mindone/diffusers/models/transformers/pixart_transformer_2d.py new file mode 100644 index 0000000000..7b6002fba5 --- /dev/null +++ b/mindone/diffusers/models/transformers/pixart_transformer_2d.py @@ -0,0 +1,311 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Optional + +import mindspore as ms +from mindspore import nn, ops + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import logging +from ..attention import BasicTransformerBlock +from ..embeddings import PatchEmbed, PixArtAlphaTextProjection +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormSingle, LayerNorm + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class PixArtTransformer2DModel(ModelMixin, ConfigMixin): + r""" + A 2D Transformer model as introduced in PixArt family of models (https://arxiv.org/abs/2310.00426, + https://arxiv.org/abs/2403.04692). + + Parameters: + num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (int, optional, defaults to 72): The number of channels in each head. + in_channels (int, defaults to 4): The number of channels in the input. + out_channels (int, optional): + The number of channels in the output. Specify this parameter if the output channel number differs from the + input. + num_layers (int, optional, defaults to 28): The number of layers of Transformer blocks to use. + dropout (float, optional, defaults to 0.0): The dropout probability to use within the Transformer blocks. + norm_num_groups (int, optional, defaults to 32): + Number of groups for group normalization within Transformer blocks. + cross_attention_dim (int, optional): + The dimensionality for cross-attention layers, typically matching the encoder's hidden dimension. + attention_bias (bool, optional, defaults to True): + Configure if the Transformer blocks' attention should contain a bias parameter. + sample_size (int, defaults to 128): + The width of the latent images. This parameter is fixed during training. + patch_size (int, defaults to 2): + Size of the patches the model processes, relevant for architectures working on non-sequential data. + activation_fn (str, optional, defaults to "gelu-approximate"): + Activation function to use in feed-forward networks within Transformer blocks. + num_embeds_ada_norm (int, optional, defaults to 1000): + Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during + inference. + upcast_attention (bool, optional, defaults to False): + If true, upcasts the attention mechanism dimensions for potentially improved performance. + norm_type (str, optional, defaults to "ada_norm_zero"): + Specifies the type of normalization used, can be 'ada_norm_zero'. + norm_elementwise_affine (bool, optional, defaults to False): + If true, enables element-wise affine parameters in the normalization layers. + norm_eps (float, optional, defaults to 1e-6): + A small constant added to the denominator in normalization layers to prevent division by zero. + interpolation_scale (int, optional): Scale factor to use during interpolating the position embeddings. + use_additional_conditions (bool, optional): If we're using additional conditions as inputs. + attention_type (str, optional, defaults to "default"): Kind of attention mechanism to be used. + caption_channels (int, optional, defaults to None): + Number of channels to use for projecting the caption embeddings. + use_linear_projection (bool, optional, defaults to False): + Deprecated argument. Will be removed in a future version. + num_vector_embeds (bool, optional, defaults to False): + Deprecated argument. Will be removed in a future version. + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["BasicTransformerBlock", "PatchEmbed"] + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 72, + in_channels: int = 4, + out_channels: Optional[int] = 8, + num_layers: int = 28, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = 1152, + attention_bias: bool = True, + sample_size: int = 128, + patch_size: int = 2, + activation_fn: str = "gelu-approximate", + num_embeds_ada_norm: Optional[int] = 1000, + upcast_attention: bool = False, + norm_type: str = "ada_norm_single", + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-6, + interpolation_scale: Optional[int] = None, + use_additional_conditions: Optional[bool] = None, + caption_channels: Optional[int] = None, + attention_type: Optional[str] = "default", + ): + super().__init__() + + # Validate inputs. + if norm_type != "ada_norm_single": + raise NotImplementedError( + f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'." + ) + elif norm_type == "ada_norm_single" and num_embeds_ada_norm is None: + raise ValueError( + f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None." + ) + + # Set some common variables used across the board. + self.attention_head_dim = attention_head_dim + self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + self.out_channels = in_channels if out_channels is None else out_channels + if use_additional_conditions is None: + if sample_size == 128: + use_additional_conditions = True + else: + use_additional_conditions = False + self.use_additional_conditions = use_additional_conditions + + self.gradient_checkpointing = False + + # 2. Initialize the position embedding and transformer blocks. + self.height = self.config.sample_size + self.width = self.config.sample_size + + interpolation_scale = ( + self.config.interpolation_scale + if self.config.interpolation_scale is not None + else max(self.config.sample_size // 64, 1) + ) + self.pos_embed = PatchEmbed( + height=self.config.sample_size, + width=self.config.sample_size, + patch_size=self.config.patch_size, + in_channels=self.config.in_channels, + embed_dim=self.inner_dim, + interpolation_scale=interpolation_scale, + ) + + self.transformer_blocks = nn.CellList( + [ + BasicTransformerBlock( + self.inner_dim, + self.config.num_attention_heads, + self.config.attention_head_dim, + dropout=self.config.dropout, + cross_attention_dim=self.config.cross_attention_dim, + activation_fn=self.config.activation_fn, + num_embeds_ada_norm=self.config.num_embeds_ada_norm, + attention_bias=self.config.attention_bias, + upcast_attention=self.config.upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=self.config.norm_elementwise_affine, + norm_eps=self.config.norm_eps, + attention_type=self.config.attention_type, + ) + for _ in range(self.config.num_layers) + ] + ) + + # 3. Output blocks. + self.norm_out = LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = ms.Parameter( + ops.randn(2, self.inner_dim) / self.inner_dim**0.5, name="scale_shift_table" + ) + self.proj_out = nn.Dense(self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels) + + self.adaln_single = AdaLayerNormSingle(self.inner_dim, use_additional_conditions=self.use_additional_conditions) + self.caption_projection = None + if self.config.caption_channels is not None: + self.caption_projection = PixArtAlphaTextProjection( + in_features=self.config.caption_channels, hidden_size=self.inner_dim + ) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def construct( + self, + hidden_states: ms.Tensor, + encoder_hidden_states: Optional[ms.Tensor] = None, + timestep: Optional[ms.Tensor] = None, + added_cond_kwargs: Dict[str, ms.Tensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[ms.Tensor] = None, + encoder_attention_mask: Optional[ms.Tensor] = None, + return_dict: bool = False, + ): + """ + The [`PixArtTransformer2DModel`] forward method. + + Args: + hidden_states (`ms.Tensor` of shape `(batch size, channel, height, width)`): + Input `hidden_states`. + encoder_hidden_states (`ms.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep (`ms.Tensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + added_cond_kwargs: (`Dict[str, Any]`, *optional*): Additional conditions to be used as inputs. + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + attention_mask ( `ms.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + encoder_attention_mask ( `ms.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + if self.use_additional_conditions and added_cond_kwargs is None: + raise ValueError("`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`.") + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 1. Input + batch_size = hidden_states.shape[0] + height, width = ( + hidden_states.shape[-2] // self.config["patch_size"], + hidden_states.shape[-1] // self.config["patch_size"], + ) + hidden_states = self.pos_embed(hidden_states) + + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + if self.caption_projection is not None: + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=None, + ) + + # 3. Output + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, axis=1) + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + if hidden_states.shape[1] == 1: + hidden_states = hidden_states.squeeze(1) + + # unpatchify + hidden_states = hidden_states.reshape( + -1, height, width, self.config["patch_size"], self.config["patch_size"], self.out_channels + ) + # hidden_states = ops.einsum("nhwpqc->nchpwq", hidden_states) + hidden_states = hidden_states.transpose(0, 5, 1, 3, 2, 4) + output = hidden_states.reshape( + -1, self.out_channels, height * self.config["patch_size"], width * self.config["patch_size"] + ) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/mindone/diffusers/models/transformers/prior_transformer.py b/mindone/diffusers/models/transformers/prior_transformer.py index 1de28c0482..c4aab87c35 100644 --- a/mindone/diffusers/models/transformers/prior_transformer.py +++ b/mindone/diffusers/models/transformers/prior_transformer.py @@ -257,13 +257,13 @@ def construct( attention_mask (`ms.Tensor` of shape `(batch_size, num_embeddings)`): Text mask for the text embeddings. return_dict (`bool`, *optional*, defaults to `False`): - Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain - tuple. + Whether or not to return a [`~models.transformers.prior_transformer.PriorTransformerOutput`] instead of + a plain tuple. Returns: - [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`: - If return_dict is True, a [`~models.prior_transformer.PriorTransformerOutput`] is returned, otherwise a - tuple is returned where the first element is the sample tensor. + [`~models.transformers.prior_transformer.PriorTransformerOutput`] or `tuple`: + If return_dict is True, a [`~models.transformers.prior_transformer.PriorTransformerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. """ batch_size = hidden_states.shape[0] diff --git a/mindone/diffusers/models/transformers/transformer_2d.py b/mindone/diffusers/models/transformers/transformer_2d.py index 738ed7d61c..f544d9ee87 100644 --- a/mindone/diffusers/models/transformers/transformer_2d.py +++ b/mindone/diffusers/models/transformers/transformer_2d.py @@ -11,37 +11,30 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass from typing import Any, Dict, Optional import mindspore as ms from mindspore import nn, ops -from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import BaseOutput, deprecate, logging +from ...configuration_utils import LegacyConfigMixin, register_to_config +from ...utils import deprecate, logging from ..attention import BasicTransformerBlock from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection -from ..modeling_utils import ModelMixin +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import LegacyModelMixin from ..normalization import AdaLayerNormSingle, GroupNorm, LayerNorm logger = logging.get_logger(__name__) # pylint: disable=invalid-name -@dataclass -class Transformer2DModelOutput(BaseOutput): - """ - The output of [`Transformer2DModel`]. - - Args: - sample (`ms.Tensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): # noqa: E501 - The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability - distributions for the unnoised latent pixels. - """ - - sample: ms.Tensor +class Transformer2DModelOutput(Transformer2DModelOutput): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead." # noqa: E501 + deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message) + super().__init__(*args, **kwargs) -class Transformer2DModel(ModelMixin, ConfigMixin): +class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin): """ A 2D Transformer model for image-like data. @@ -70,6 +63,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin): """ _supports_gradient_checkpointing = True + _no_split_modules = ["BasicTransformerBlock"] @register_to_config def __init__( @@ -98,8 +92,11 @@ def __init__( attention_type: str = "default", caption_channels: int = None, interpolation_scale: float = None, + use_additional_conditions: Optional[bool] = None, ): super().__init__() + + # Validate inputs. if patch_size is not None: if norm_type not in ["ada_norm", "ada_norm_zero", "ada_norm_single"]: raise NotImplementedError( @@ -110,31 +107,12 @@ def __init__( f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None." ) - self.use_linear_projection = use_linear_projection - self.num_attention_heads = num_attention_heads - self.attention_head_dim = attention_head_dim - inner_dim = num_attention_heads * attention_head_dim - - conv_cls = nn.Conv2d - linear_cls = nn.Dense - # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` # noqa: E501 # Define whether input is continuous or discrete depending on configuration self.is_input_continuous = (in_channels is not None) and (patch_size is None) self.is_input_vectorized = num_vector_embeds is not None self.is_input_patches = in_channels is not None and patch_size is not None - if norm_type == "layer_norm" and num_embeds_ada_norm is not None: - deprecation_message = ( - f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" - " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." - " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" - " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" - " would be very nice if you could open a Pull request for the `transformer/config.json` file" - ) - deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) - norm_type = "ada_norm" - if self.is_input_continuous and self.is_input_vectorized: raise ValueError( f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" @@ -151,104 +129,202 @@ def __init__( f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." ) - # 2. Define input layers - if self.is_input_continuous: - self.in_channels = in_channels + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = ( + f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" + " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config." + " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" + " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" + " would be very nice if you could open a Pull request for the `transformer/config.json` file" + ) + deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) + norm_type = "ada_norm" + + # Set some common variables used across the board. + self.use_linear_projection = use_linear_projection + self.interpolation_scale = interpolation_scale + self.caption_channels = caption_channels + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels - self.norm = GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) - if use_linear_projection: - self.proj_in = linear_cls(in_channels, inner_dim) + if use_additional_conditions is None: + if norm_type == "ada_norm_single" and sample_size == 128: + use_additional_conditions = True else: - self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0, has_bias=True) + use_additional_conditions = False + self.use_additional_conditions = use_additional_conditions + + # 2. Initialize the right blocks. + # These functions follow a common structure: + # a. Initialize the input blocks. b. Initialize the transformer blocks. + # c. Initialize the output blocks and other projection blocks when necessary. + if self.is_input_continuous: + self._init_continuous_input(norm_type=norm_type) elif self.is_input_vectorized: - assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" - assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" + self._init_vectorized_inputs(norm_type=norm_type) + elif self.is_input_patches: + self._init_patched_inputs(norm_type=norm_type) - self.height = sample_size - self.width = sample_size - self.num_vector_embeds = num_vector_embeds - self.num_latent_pixels = self.height * self.width + # Move here to call `gradient_checkpointing.setter` after self.transformer_blocks initiated + self._gradient_checkpointing = False - self.latent_image_embedding = ImagePositionalEmbeddings( - num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width + def _init_continuous_input(self, norm_type): + self.norm = GroupNorm( + num_groups=self.config.norm_num_groups, num_channels=self.in_channels, eps=1e-6, affine=True + ) + if self.use_linear_projection: + self.proj_in = nn.Dense(self.in_channels, self.inner_dim) + else: + self.proj_in = nn.Conv2d( + self.in_channels, self.inner_dim, kernel_size=1, stride=1, pad_mode="pad", padding=0, has_bias=True ) - elif self.is_input_patches: - assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" - self.height = sample_size - self.width = sample_size + self.transformer_blocks = nn.CellList( + [ + BasicTransformerBlock( + self.inner_dim, + self.config.num_attention_heads, + self.config.attention_head_dim, + dropout=self.config.dropout, + cross_attention_dim=self.config.cross_attention_dim, + activation_fn=self.config.activation_fn, + num_embeds_ada_norm=self.config.num_embeds_ada_norm, + attention_bias=self.config.attention_bias, + only_cross_attention=self.config.only_cross_attention, + double_self_attention=self.config.double_self_attention, + upcast_attention=self.config.upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=self.config.norm_elementwise_affine, + norm_eps=self.config.norm_eps, + attention_type=self.config.attention_type, + ) + for _ in range(self.config.num_layers) + ] + ) - self.patch_size = patch_size - interpolation_scale = ( - interpolation_scale if interpolation_scale is not None else max(self.config.sample_size // 64, 1) - ) - self.pos_embed = PatchEmbed( - height=sample_size, - width=sample_size, - patch_size=patch_size, - in_channels=in_channels, - embed_dim=inner_dim, - interpolation_scale=interpolation_scale, + if self.use_linear_projection: + self.proj_out = nn.Dense(self.inner_dim, self.out_channels) + else: + self.proj_out = nn.Conv2d( + self.inner_dim, self.out_channels, kernel_size=1, stride=1, pad_mode="pad", padding=0, has_bias=True ) - # 3. Define transformers blocks + def _init_vectorized_inputs(self, norm_type): + assert self.config.sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" + assert ( + self.config.num_vector_embeds is not None + ), "Transformer2DModel over discrete input must provide num_embed" + + self.height = self.config.sample_size + self.width = self.config.sample_size + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=self.config.num_vector_embeds, embed_dim=self.inner_dim, height=self.height, width=self.width + ) + self.transformer_blocks = nn.CellList( [ BasicTransformerBlock( - inner_dim, - num_attention_heads, - attention_head_dim, - dropout=dropout, - cross_attention_dim=cross_attention_dim, - activation_fn=activation_fn, - num_embeds_ada_norm=num_embeds_ada_norm, - attention_bias=attention_bias, - only_cross_attention=only_cross_attention, - double_self_attention=double_self_attention, - upcast_attention=upcast_attention, + self.inner_dim, + self.config.num_attention_heads, + self.config.attention_head_dim, + dropout=self.config.dropout, + cross_attention_dim=self.config.cross_attention_dim, + activation_fn=self.config.activation_fn, + num_embeds_ada_norm=self.config.num_embeds_ada_norm, + attention_bias=self.config.attention_bias, + only_cross_attention=self.config.only_cross_attention, + double_self_attention=self.config.double_self_attention, + upcast_attention=self.config.upcast_attention, norm_type=norm_type, - norm_elementwise_affine=norm_elementwise_affine, - norm_eps=norm_eps, - attention_type=attention_type, + norm_elementwise_affine=self.config.norm_elementwise_affine, + norm_eps=self.config.norm_eps, + attention_type=self.config.attention_type, ) - for d in range(num_layers) + for _ in range(self.config.num_layers) ] ) - # 4. Define output layers - self.out_channels = in_channels if out_channels is None else out_channels - if self.is_input_continuous: - # TODO: should use out_channels for continuous projections - if use_linear_projection: - self.proj_out = linear_cls(inner_dim, in_channels) - else: - self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0, has_bias=True) - elif self.is_input_vectorized: - self.norm_out = LayerNorm(inner_dim) - self.out = nn.Dense(inner_dim, self.num_vector_embeds - 1) - elif self.is_input_patches and norm_type != "ada_norm_single": - self.norm_out = LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) - self.proj_out_1 = nn.Dense(inner_dim, 2 * inner_dim) - self.proj_out_2 = nn.Dense(inner_dim, patch_size * patch_size * self.out_channels) - elif self.is_input_patches and norm_type == "ada_norm_single": - self.norm_out = LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) - self.scale_shift_table = ms.Parameter(ops.randn(2, inner_dim) / inner_dim**0.5, name="scale_shift_table") - self.proj_out = nn.Dense(inner_dim, patch_size * patch_size * self.out_channels) - - # 5. PixArt-Alpha blocks. + self.norm_out = LayerNorm(self.inner_dim) + self.out = nn.Dense(self.inner_dim, self.config.num_vector_embeds - 1) + + def _init_patched_inputs(self, norm_type): + assert self.config.sample_size is not None, "Transformer2DModel over patched input must provide sample_size" + + self.height = self.config.sample_size + self.width = self.config.sample_size + + self.patch_size = self.config.patch_size + interpolation_scale = ( + self.config.interpolation_scale + if self.config.interpolation_scale is not None + else max(self.config.sample_size // 64, 1) + ) + self.pos_embed = PatchEmbed( + height=self.config.sample_size, + width=self.config.sample_size, + patch_size=self.config.patch_size, + in_channels=self.in_channels, + embed_dim=self.inner_dim, + interpolation_scale=interpolation_scale, + ) + + self.transformer_blocks = nn.CellList( + [ + BasicTransformerBlock( + self.inner_dim, + self.config.num_attention_heads, + self.config.attention_head_dim, + dropout=self.config.dropout, + cross_attention_dim=self.config.cross_attention_dim, + activation_fn=self.config.activation_fn, + num_embeds_ada_norm=self.config.num_embeds_ada_norm, + attention_bias=self.config.attention_bias, + only_cross_attention=self.config.only_cross_attention, + double_self_attention=self.config.double_self_attention, + upcast_attention=self.config.upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=self.config.norm_elementwise_affine, + norm_eps=self.config.norm_eps, + attention_type=self.config.attention_type, + ) + for _ in range(self.config.num_layers) + ] + ) + + if self.config.norm_type != "ada_norm_single": + self.norm_out = LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Dense(self.inner_dim, 2 * self.inner_dim) + self.proj_out_2 = nn.Dense( + self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels + ) + elif self.config.norm_type == "ada_norm_single": + self.norm_out = LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = ms.Parameter( + ops.randn(2, self.inner_dim) / self.inner_dim**0.5, name="scale_shift_table" + ) + self.proj_out = nn.Dense( + self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels + ) + + # PixArt-Alpha blocks. self.adaln_single = None - self.use_additional_conditions = False - if norm_type == "ada_norm_single": - self.use_additional_conditions = self.config.sample_size == 128 + if self.config.norm_type == "ada_norm_single": # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use # additional conditions until we find better name - self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions) + self.adaln_single = AdaLayerNormSingle( + self.inner_dim, use_additional_conditions=self.use_additional_conditions + ) self.caption_projection = None - if caption_channels is not None: - self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) - - self._gradient_checkpointing = False + if self.caption_channels is not None: + self.caption_projection = PixArtAlphaTextProjection( + in_features=self.caption_channels, hidden_size=self.inner_dim + ) def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): @@ -311,8 +387,8 @@ def construct( tuple. Returns: - If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a - `tuple` where the first element is the sample tensor. + If `return_dict` is True, an [`~models.transformers.transformer_2d.Transformer2DModelOutput`] is returned, + otherwise a `tuple` where the first element is the sample tensor. """ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. @@ -339,42 +415,21 @@ def construct( # 1. Input # define variables outside to fool ai compiler - embedded_timestep, batch, inner_dim, height, width, residual = (None,) * 6 + embedded_timestep, batch_size, inner_dim, height, width, residual = (None,) * 6 if self.is_input_continuous: - batch, _, height, width = hidden_states.shape + batch_size, _, height, width = hidden_states.shape residual = hidden_states - - hidden_states = self.norm(hidden_states) - if not self.use_linear_projection: - hidden_states = self.proj_in(hidden_states) - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) - else: - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) - hidden_states = self.proj_in(hidden_states) - + hidden_states, inner_dim = self._operate_on_continuous_inputs(hidden_states) elif self.is_input_vectorized: hidden_states = self.latent_image_embedding(hidden_states) elif self.is_input_patches: height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size - hidden_states = self.pos_embed(hidden_states) - - if self.adaln_single is not None: - if self.use_additional_conditions and added_cond_kwargs is None: - raise ValueError( - "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." - ) - timestep, embedded_timestep = self.adaln_single( - timestep, added_cond_kwargs, batch_size=hidden_states.shape[0], hidden_dtype=hidden_states.dtype - ) + hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs( + hidden_states, encoder_hidden_states, timestep, added_cond_kwargs + ) # 2. Blocks - if self.caption_projection is not None: - encoder_hidden_states = self.caption_projection(encoder_hidden_states) - encoder_hidden_states = encoder_hidden_states.view(hidden_states.shape[0], -1, hidden_states.shape[-1]) - for block in self.transformer_blocks: hidden_states = block( hidden_states, @@ -389,51 +444,110 @@ def construct( # 3. Output output = None if self.is_input_continuous: - if not self.use_linear_projection: - hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2) - hidden_states = self.proj_out(hidden_states) - else: - hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2) - - output = hidden_states + residual + output = self._get_output_for_continuous_inputs( + hidden_states=hidden_states, + residual=residual, + batch_size=batch_size, + height=height, + width=width, + inner_dim=inner_dim, + ) elif self.is_input_vectorized: - hidden_states = self.norm_out(hidden_states) - logits = self.out(hidden_states) - # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) - logits = logits.permute(0, 2, 1) - - # log(p(x_0)) - # ops.log_softmax doesn't support double precision. why assume float output - output = ops.log_softmax(logits.float(), axis=1).to(hidden_states.dtype) - - if self.is_input_patches: - if self.config["norm_type"] != "ada_norm_single": - conditioning = self.transformer_blocks[0].norm1.emb( - timestep, class_labels, hidden_dtype=hidden_states.dtype - ) - shift, scale = self.proj_out_1(ops.silu(conditioning)).chunk(2, axis=1) - hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] - hidden_states = self.proj_out_2(hidden_states) - elif self.config["norm_type"] == "ada_norm_single": - shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, axis=1) - hidden_states = self.norm_out(hidden_states) - # Modulation - hidden_states = hidden_states * (1 + scale) + shift - hidden_states = self.proj_out(hidden_states) - if hidden_states.shape[1] == 1: - hidden_states = hidden_states.squeeze(1) - - # unpatchify - if self.adaln_single is None: - height = width = int(hidden_states.shape[1] ** 0.5) - hidden_states = hidden_states.reshape( - -1, height, width, self.patch_size, self.patch_size, self.out_channels + output = self._get_output_for_vectorized_inputs(hidden_states) + elif self.is_input_patches: + output = self._get_output_for_patched_inputs( + hidden_states=hidden_states, + timestep=timestep, + class_labels=class_labels, + embedded_timestep=embedded_timestep, + height=height, + width=width, ) - hidden_states = hidden_states.transpose(0, 5, 1, 3, 2, 4) - output = hidden_states.reshape(-1, self.out_channels, height * self.patch_size, width * self.patch_size) if not return_dict: return (output,) return Transformer2DModelOutput(sample=output) + + def _operate_on_continuous_inputs(self, hidden_states): + batch, _, height, width = hidden_states.shape + hidden_states = self.norm(hidden_states) + + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + hidden_states = self.proj_in(hidden_states) + + return hidden_states, inner_dim + + def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, added_cond_kwargs): + batch_size = hidden_states.shape[0] + hidden_states = self.pos_embed(hidden_states) + embedded_timestep = None + + if self.adaln_single is not None: + if self.use_additional_conditions and added_cond_kwargs is None: + raise ValueError( + "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." + ) + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + if self.caption_projection is not None: + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) + + return hidden_states, encoder_hidden_states, timestep, embedded_timestep + + def _get_output_for_continuous_inputs(self, hidden_states, residual, batch_size, height, width, inner_dim): + if not self.use_linear_projection: + hidden_states = hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2) + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2) + + output = hidden_states + residual + return output + + def _get_output_for_vectorized_inputs(self, hidden_states): + hidden_states = self.norm_out(hidden_states) + logits = self.out(hidden_states) + # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) + logits = logits.permute(0, 2, 1) + # log(p(x_0)) + output = ops.log_softmax(logits.float(), axis=1).to(hidden_states.dtype) + return output + + def _get_output_for_patched_inputs( + self, hidden_states, timestep, class_labels, embedded_timestep, height=None, width=None + ): + if self.config.norm_type != "ada_norm_single": + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + shift, scale = self.proj_out_1(ops.silu(conditioning)).chunk(2, axis=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + elif self.config.norm_type == "ada_norm_single": + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, axis=1) + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + if hidden_states.shape[1] == 1: + hidden_states = hidden_states.squeeze(1) + + # unpatchify + if self.adaln_single is None: + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + # hidden_states = ops.einsum("nhwpqc->nchpwq", hidden_states) + hidden_states = hidden_states.transpose(0, 5, 1, 3, 2, 4) + output = hidden_states.reshape(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + return output diff --git a/mindone/diffusers/models/transformers/transformer_sd3.py b/mindone/diffusers/models/transformers/transformer_sd3.py index 2fcec79acb..66383d5417 100644 --- a/mindone/diffusers/models/transformers/transformer_sd3.py +++ b/mindone/diffusers/models/transformers/transformer_sd3.py @@ -1,4 +1,4 @@ -# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. +# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,13 +13,13 @@ # limitations under the License. -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, List, Optional, Union import mindspore as ms from mindspore import nn from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import PeftAdapterMixin +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...models.attention import JointTransformerBlock from ...models.attention_processor import AttentionProcessor from ...models.modeling_utils import ModelMixin @@ -31,7 +31,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): +class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): """ The Transformer model introduced in Stable Diffusion 3. @@ -186,6 +186,7 @@ def construct( encoder_hidden_states: ms.Tensor = None, pooled_projections: ms.Tensor = None, timestep: ms.Tensor = None, + block_controlnet_hidden_states: List = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = False, ) -> Union[ms.Tensor, Transformer2DModelOutput]: @@ -201,6 +202,8 @@ def construct( from the embeddings of input conditions. timestep ( `ms.Tensor`): Used to indicate denoising step. + block_controlnet_hidden_states: (`list` of `torch.Tensor`): + A list of tensors that if specified are added to the residuals of transformer blocks. joint_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in @@ -213,12 +216,12 @@ def construct( If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ - if joint_attention_kwargs is not None and "lora_scale" in joint_attention_kwargs: + if joint_attention_kwargs is not None and "scale" in joint_attention_kwargs: # weight the lora layers by setting `lora_scale` for each PEFT layer here # and remove `lora_scale` from each PEFT layer at the end. # scale_lora_layers & unscale_lora_layers maybe contains some operation forbidden in graph mode raise RuntimeError( - f"You are trying to set scaling of lora layer by passing {joint_attention_kwargs['lora_scale']=}. " + f"You are trying to set scaling of lora layer by passing {joint_attention_kwargs['scale']=}. " f"However it's not allowed in on-the-fly model forwarding. " f"Please manually call `scale_lora_layers(model, lora_scale)` before model forwarding and " f"`unscale_lora_layers(model, lora_scale)` after model forwarding. " @@ -231,11 +234,16 @@ def construct( temb = self.time_text_embed(timestep, pooled_projections) encoder_hidden_states = self.context_embedder(encoder_hidden_states) - for block in self.transformer_blocks: + for index_block, block in enumerate(self.transformer_blocks): encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb ) + # controlnet residual + if block_controlnet_hidden_states is not None and block.context_pre_only is False: + interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states) + hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control] + hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) diff --git a/mindone/diffusers/models/transformers/transformer_temporal.py b/mindone/diffusers/models/transformers/transformer_temporal.py index 74a51e87ff..29aeea0800 100644 --- a/mindone/diffusers/models/transformers/transformer_temporal.py +++ b/mindone/diffusers/models/transformers/transformer_temporal.py @@ -151,13 +151,14 @@ def construct( `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain - tuple. + Whether or not to return a [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`] + instead of a plain tuple. Returns: - [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`: - If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is - returned, otherwise a `tuple` where the first element is the sample tensor. + [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`: + If `return_dict` is True, an + [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. """ # 1. Input batch_frames, channel, height, width = hidden_states.shape @@ -293,13 +294,14 @@ def construct( A tensor indicating whether the input contains only images. 1 indicates that the input contains only images, 0 indicates that the input contains video frames. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain - tuple. + Whether or not to return a [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`] + instead of a plain tuple. Returns: - [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`: - If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is - returned, otherwise a `tuple` where the first element is the sample tensor. + [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`: + If `return_dict` is True, an + [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. """ # 1. Input batch_frames, _, height, width = hidden_states.shape @@ -310,10 +312,10 @@ def construct( time_context_first_timestep = time_context[None, :].reshape(batch_size, num_frames, -1, time_context.shape[-1])[ :, 0 ] - time_context = time_context_first_timestep[None, :].broadcast_to( - (height * width, batch_size, 1, time_context.shape[-1]) + time_context = time_context_first_timestep[:, None].broadcast_to( + (batch_size, height * width, time_context.shape[-2], time_context.shape[-1]) ) - time_context = time_context.reshape(height * width * batch_size, 1, time_context.shape[-1]) + time_context = time_context.reshape(batch_size * height * width, -1, time_context.shape[-1]) residual = hidden_states diff --git a/mindone/diffusers/models/unets/unet_1d.py b/mindone/diffusers/models/unets/unet_1d.py index 489c834c8e..94d830364a 100644 --- a/mindone/diffusers/models/unets/unet_1d.py +++ b/mindone/diffusers/models/unets/unet_1d.py @@ -207,11 +207,11 @@ def construct( The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`. timestep (`ms.Tensor` or `float` or `int`): The number of timesteps to denoise an input. return_dict (`bool`, *optional*, defaults to `False`): - Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple. + Whether or not to return a [`~models.unets.unet_1d.UNet1DOutput`] instead of a plain tuple. Returns: - [`~models.unet_1d.UNet1DOutput`] or `tuple`: - If `return_dict` is True, an [`~models.unet_1d.UNet1DOutput`] is returned, otherwise a `tuple` is + [`~models.unets.unet_1d.UNet1DOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_1d.UNet1DOutput`] is returned, otherwise a `tuple` is returned where the first element is the sample tensor. """ diff --git a/mindone/diffusers/models/unets/unet_2d.py b/mindone/diffusers/models/unets/unet_2d.py index 0278480847..f0b92fcaa2 100644 --- a/mindone/diffusers/models/unets/unet_2d.py +++ b/mindone/diffusers/models/unets/unet_2d.py @@ -269,11 +269,11 @@ def construct( class_labels (`ms.Tensor`, *optional*, defaults to `None`): Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. return_dict (`bool`, *optional*, defaults to `False`): - Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple. + Whether or not to return a [`~models.unets.unet_2d.UNet2DOutput`] instead of a plain tuple. Returns: - [`~models.unet_2d.UNet2DOutput`] or `tuple`: - If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is + [`~models.unets.unet_2d.UNet2DOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is returned where the first element is the sample tensor. """ # 0. center input if necessary diff --git a/mindone/diffusers/models/unets/unet_2d_blocks.py b/mindone/diffusers/models/unets/unet_2d_blocks.py index 8e50ae9812..6d2b60131d 100644 --- a/mindone/diffusers/models/unets/unet_2d_blocks.py +++ b/mindone/diffusers/models/unets/unet_2d_blocks.py @@ -751,6 +751,7 @@ def __init__( self, in_channels: int, temb_channels: int, + out_channels: Optional[int] = None, dropout: float = 0.0, num_layers: int = 1, transformer_layers_per_block: Union[int, Tuple[int]] = 1, @@ -758,6 +759,7 @@ def __init__( resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, + resnet_groups_out: Optional[int] = None, resnet_pre_norm: bool = True, num_attention_heads: int = 1, output_scale_factor: float = 1.0, @@ -769,6 +771,10 @@ def __init__( ): super().__init__() + out_channels = out_channels or in_channels + self.in_channels = in_channels + self.out_channels = out_channels + self.has_cross_attention = True self.has_motion_modules = False self.num_attention_heads = num_attention_heads @@ -778,14 +784,17 @@ def __init__( if isinstance(transformer_layers_per_block, int): transformer_layers_per_block = [transformer_layers_per_block] * num_layers + resnet_groups_out = resnet_groups_out or resnet_groups + # there is always at least one resnet resnets = [ ResnetBlock2D( in_channels=in_channels, - out_channels=in_channels, + out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, + groups_out=resnet_groups_out, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, @@ -800,11 +809,11 @@ def __init__( attentions.append( Transformer2DModel( num_attention_heads, - in_channels // num_attention_heads, - in_channels=in_channels, + out_channels // num_attention_heads, + in_channels=out_channels, num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, + norm_num_groups=resnet_groups_out, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, attention_type=attention_type, @@ -814,8 +823,8 @@ def __init__( attentions.append( DualTransformer2DModel( num_attention_heads, - in_channels // num_attention_heads, - in_channels=in_channels, + out_channels // num_attention_heads, + in_channels=out_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, @@ -823,11 +832,11 @@ def __init__( ) resnets.append( ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, + in_channels=out_channels, + out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, - groups=resnet_groups, + groups=resnet_groups_out, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, diff --git a/mindone/diffusers/models/unets/unet_2d_condition.py b/mindone/diffusers/models/unets/unet_2d_condition.py index 5c437cd786..d91e38f2e9 100644 --- a/mindone/diffusers/models/unets/unet_2d_condition.py +++ b/mindone/diffusers/models/unets/unet_2d_condition.py @@ -19,6 +19,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin +from ...loaders.single_file_model import FromOriginalModelMixin from ...utils import BaseOutput, logging from ..activations import get_activation from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor @@ -54,7 +55,9 @@ class UNet2DConditionOutput(BaseOutput): sample: ms.Tensor = None -class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin): +class UNet2DConditionModel( + ModelMixin, ConfigMixin, FromOriginalModelMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin +): r""" A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample shaped output. @@ -95,13 +98,13 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, The dimension of the cross attention features. transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for - [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], - [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. - reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None): + [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None): The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for - [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], - [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. encoder_hid_dim (`int`, *optional*, defaults to None): If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` dimension to `cross_attention_dim`. @@ -149,6 +152,7 @@ class conditioning with `class_embed_type` equal to `None`. """ _supports_gradient_checkpointing = True + _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"] @register_to_config def __init__( @@ -587,7 +591,7 @@ def _set_encoder_hid_proj( elif encoder_hid_dim_type == "text_image_proj": # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use - # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` + # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)` self.encoder_hid_proj = TextImageProjection( text_embed_dim=encoder_hid_dim, image_embed_dim=cross_attention_dim, @@ -667,7 +671,7 @@ def _set_add_embedding( elif addition_embed_type == "text_image": # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use - # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` + # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)` self.add_embedding = TextImageTimeEmbedding( text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim ) @@ -688,7 +692,7 @@ def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: i positive_len = 768 if isinstance(cross_attention_dim, int): positive_len = cross_attention_dim - elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list): + elif isinstance(cross_attention_dim, (list, tuple)): positive_len = cross_attention_dim[0] feature_type = "text-only" if attention_type == "gated" else "text-image" @@ -872,7 +876,7 @@ def process_encoder_hidden_states( if self.encoder_hid_proj is not None and self.encoder_hid_dim_type == "text_proj": encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) elif self.encoder_hid_proj is not None and self.encoder_hid_dim_type == "text_image_proj": - # Kadinsky 2.1 - style + # Kandinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" # noqa: E501 @@ -955,8 +959,8 @@ def construct( Returns: [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: - If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise - a `tuple` is returned where the first element is the sample tensor. + If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, + otherwise a `tuple` is returned where the first element is the sample tensor. """ # By default samples have to be AT least a multiple of the overall upsampling factor. # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). diff --git a/mindone/diffusers/models/unets/unet_3d_blocks.py b/mindone/diffusers/models/unets/unet_3d_blocks.py index b2d63c4c38..981ea43315 100644 --- a/mindone/diffusers/models/unets/unet_3d_blocks.py +++ b/mindone/diffusers/models/unets/unet_3d_blocks.py @@ -110,6 +110,7 @@ def get_down_block( raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMotion") return CrossAttnDownBlockMotion( num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -244,6 +245,7 @@ def get_up_block( raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMotion") return CrossAttnUpBlockMotion( num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, diff --git a/mindone/diffusers/models/unets/unet_3d_condition.py b/mindone/diffusers/models/unets/unet_3d_condition.py index 177f0136a3..eeba5492e1 100644 --- a/mindone/diffusers/models/unets/unet_3d_condition.py +++ b/mindone/diffusers/models/unets/unet_3d_condition.py @@ -83,6 +83,8 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) cross_attention_dim (`int`, *optional*, defaults to 1024): The dimension of the cross attention features. attention_head_dim (`int`, *optional*, defaults to 64): The dimension of the attention heads. num_attention_heads (`int`, *optional*): The number of attention heads. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. """ _supports_gradient_checkpointing = False @@ -115,6 +117,7 @@ def __init__( cross_attention_dim: int = 1024, attention_head_dim: Union[int, Tuple[int]] = 64, num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + time_cond_proj_dim: Optional[int] = None, ): super().__init__() @@ -177,6 +180,7 @@ def __init__( timestep_input_dim, time_embed_dim, act_fn=act_fn, + cond_proj_dim=time_cond_proj_dim, ) self.transformer_in = TransformerTemporalModel( @@ -374,10 +378,10 @@ def fn_recursive_feed_forward(module: nn.Cell, chunk_size: int, dim: int): if hasattr(module, "set_chunk_feed_forward"): module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) - for child in module.name_cells().items(): + for child in module.name_cells().values(): fn_recursive_feed_forward(child, chunk_size, dim) - for module in self.name_cells().items(): + for module in self.name_cells().values(): fn_recursive_feed_forward(module, chunk_size, dim) def disable_forward_chunking(self): @@ -385,10 +389,10 @@ def fn_recursive_feed_forward(module: nn.Cell, chunk_size: int, dim: int): if hasattr(module, "set_chunk_feed_forward"): module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) - for child in module.name_cells().items(): + for child in module.name_cells().values(): fn_recursive_feed_forward(child, chunk_size, dim) - for module in self.name_cells().items(): + for module in self.name_cells().values(): fn_recursive_feed_forward(module, None, 0) # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor @@ -449,15 +453,15 @@ def construct( mid_block_additional_residual: (`ms.Tensor`, *optional*): A tensor that if specified is added to the residual of the middle unet block. return_dict (`bool`, *optional*, defaults to `False`): - Whether or not to return a [`~models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain + Whether or not to return a [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. Returns: - [`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`: - If `return_dict` is True, an [`~models.unet_3d_condition.UNet3DConditionOutput`] is returned, otherwise - a `tuple` is returned where the first element is the sample tensor. + [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] is returned, + otherwise a `tuple` is returned where the first element is the sample tensor. """ # By default samples have to be AT least a multiple of the overall upsampling factor. # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). diff --git a/mindone/diffusers/models/unets/unet_i2vgen_xl.py b/mindone/diffusers/models/unets/unet_i2vgen_xl.py index ff9e0f4f4c..a198f052c9 100644 --- a/mindone/diffusers/models/unets/unet_i2vgen_xl.py +++ b/mindone/diffusers/models/unets/unet_i2vgen_xl.py @@ -92,8 +92,8 @@ def construct( class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): r""" - I2VGenXL UNet. It is a conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep - and returns a sample-shaped output. + I2VGenXL UNet. It is a conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep and + returns a sample-shaped output. This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented for all models (such as downloading or saving). @@ -403,10 +403,10 @@ def fn_recursive_feed_forward(module: nn.Cell, chunk_size: int, dim: int): if hasattr(module, "set_chunk_feed_forward"): module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) - for child in module.name_cells().items(): + for child in module.name_cells().values(): fn_recursive_feed_forward(child, chunk_size, dim) - for module in self.name_cells().items(): + for module in self.name_cells().values(): fn_recursive_feed_forward(module, chunk_size, dim) # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking @@ -415,10 +415,10 @@ def fn_recursive_feed_forward(module: nn.Cell, chunk_size: int, dim: int): if hasattr(module, "set_chunk_feed_forward"): module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) - for child in module.name_cells().items(): + for child in module.name_cells().values(): fn_recursive_feed_forward(child, chunk_size, dim) - for module in self.name_cells().items(): + for module in self.name_cells().values(): fn_recursive_feed_forward(module, None, 0) # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor @@ -461,7 +461,8 @@ def construct( timestep (`ms.Tensor` or `float` or `int`): The number of timesteps to denoise an input. fps (`ms.Tensor`): Frames per second for the video being generated. Used as a "micro-condition". image_latents (`ms.Tensor`): Image encodings from the VAE. - image_embeddings (`ms.Tensor`): Projection embeddings of the conditioning image computed with a vision encoder. + image_embeddings (`ms.Tensor`): + Projection embeddings of the conditioning image computed with a vision encoder. encoder_hidden_states (`ms.Tensor`): The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. cross_attention_kwargs (`dict`, *optional*): @@ -469,13 +470,13 @@ def construct( `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). return_dict (`bool`, *optional*, defaults to `False`): - Whether or not to return a [`~models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain + Whether or not to return a [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] instead of a plain tuple. Returns: - [`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`: - If `return_dict` is True, an [`~models.unet_3d_condition.UNet3DConditionOutput`] is returned, otherwise - a `tuple` is returned where the first element is the sample tensor. + [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] is returned, + otherwise a `tuple` is returned where the first element is the sample tensor. """ batch_size, channels, num_frames, height, width = sample.shape diff --git a/mindone/diffusers/models/unets/unet_motion_model.py b/mindone/diffusers/models/unets/unet_motion_model.py index fb97d9897f..adcb48a4ed 100644 --- a/mindone/diffusers/models/unets/unet_motion_model.py +++ b/mindone/diffusers/models/unets/unet_motion_model.py @@ -16,10 +16,10 @@ import mindspore as ms from mindspore import nn, ops -from ...configuration_utils import ConfigMixin, register_to_config +from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config from ...loaders import UNet2DConditionLoadersMixin from ...utils import logging -from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor +from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor, IPAdapterAttnProcessor from ..embeddings import TimestepEmbedding, Timesteps from ..modeling_utils import ModelMixin from ..normalization import GroupNorm @@ -206,6 +206,8 @@ def __init__( norm_num_groups: int = 32, norm_eps: float = 1e-5, cross_attention_dim: int = 1280, + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, + reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None, use_linear_projection: bool = False, num_attention_heads: Union[int, Tuple[int, ...]] = 8, motion_max_seq_length: int = 32, @@ -213,6 +215,9 @@ def __init__( use_motion_mid_block: int = True, encoder_hid_dim: Optional[int] = None, encoder_hid_dim_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + projection_class_embeddings_input_dim: Optional[int] = None, time_cond_proj_dim: Optional[int] = None, ): super().__init__() @@ -238,6 +243,21 @@ def __init__( f"`num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." ) + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." # noqa: E501 + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." # noqa: E501 + ) + + if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: + for layer_number_per_block in transformer_layers_per_block: + if isinstance(layer_number_per_block, list): + raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") + # input conv_in_kernel = 3 conv_out_kernel = 3 @@ -263,6 +283,10 @@ def __init__( if encoder_hid_dim_type is None: self.encoder_hid_proj = None + if addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, True, 0) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + # class embedding down_blocks = [] up_blocks = [] @@ -270,6 +294,15 @@ def __init__( if isinstance(num_attention_heads, int): num_attention_heads = (num_attention_heads,) * len(down_block_types) + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): @@ -279,7 +312,7 @@ def __init__( down_block = get_down_block( down_block_type, - num_layers=layers_per_block, + num_layers=layers_per_block[i], in_channels=input_channel, out_channels=output_channel, temb_channels=time_embed_dim, @@ -287,46 +320,51 @@ def __init__( resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim[i], num_attention_heads=num_attention_heads[i], downsample_padding=downsample_padding, use_linear_projection=use_linear_projection, dual_cross_attention=False, temporal_num_attention_heads=motion_num_attention_heads, temporal_max_seq_length=motion_max_seq_length, + transformer_layers_per_block=transformer_layers_per_block[i], ) down_blocks.append(down_block) self.down_blocks = nn.CellList(down_blocks) - # mid + # mid: only definition, binding attribute to UNetMotionModel later to maintain the order of sub-modules within + # UNetMotionModel as self.down_blocks -> self.up_blocks -> self.mid_block, ensuring the correct sequence of + # sub-modules is loaded when the ip-adpater is loaded. if use_motion_mid_block: - self.mid_block = UNetMidBlockCrossAttnMotion( + mid_block = UNetMidBlockCrossAttnMotion( in_channels=block_out_channels[-1], temb_channels=time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, - cross_attention_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim[-1], num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, dual_cross_attention=False, use_linear_projection=use_linear_projection, temporal_num_attention_heads=motion_num_attention_heads, temporal_max_seq_length=motion_max_seq_length, + transformer_layers_per_block=transformer_layers_per_block[-1], ) else: - self.mid_block = UNetMidBlock2DCrossAttn( + mid_block = UNetMidBlock2DCrossAttn( in_channels=block_out_channels[-1], temb_channels=time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, - cross_attention_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim[-1], num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, dual_cross_attention=False, use_linear_projection=use_linear_projection, + transformer_layers_per_block=transformer_layers_per_block[-1], ) # count how many layers upsample the images @@ -336,6 +374,9 @@ def __init__( layers_per_resnet_in_up_blocks = [] reversed_block_out_channels = list(reversed(block_out_channels)) reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): @@ -354,7 +395,7 @@ def __init__( up_block = get_up_block( up_block_type, - num_layers=layers_per_block + 1, + num_layers=reversed_layers_per_block[i] + 1, in_channels=input_channel, out_channels=output_channel, prev_output_channel=prev_output_channel, @@ -363,13 +404,14 @@ def __init__( resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, + cross_attention_dim=reversed_cross_attention_dim[i], num_attention_heads=reversed_num_attention_heads[i], dual_cross_attention=False, resolution_idx=i, use_linear_projection=use_linear_projection, temporal_num_attention_heads=motion_num_attention_heads, temporal_max_seq_length=motion_max_seq_length, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], ) up_blocks.append(up_block) prev_output_channel = output_channel @@ -377,6 +419,9 @@ def __init__( self.up_blocks = nn.CellList(up_blocks) self.layers_per_resnet_in_up_blocks = layers_per_resnet_in_up_blocks + # bind mid_block to self here + self.mid_block = mid_block + # out if norm_num_groups is not None: self.conv_norm_out = GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) @@ -405,7 +450,7 @@ def from_unet2d( has_motion_adapter = motion_adapter is not None # based on https://github.com/guoyww/AnimateDiff/blob/895f3220c06318ea0760131ec70408b466c49333/animatediff/models/unet.py#L459 - config = unet.config + config = dict(unet.config) config["_class_name"] = cls.__name__ down_blocks = [] @@ -438,6 +483,7 @@ def from_unet2d( if not config.get("num_attention_heads"): config["num_attention_heads"] = config["attention_head_dim"] + config = FrozenDict(config) model = cls.from_config(config) # Move dtype conversion code here to avoid dtype mismatch issues when loading weights @@ -459,6 +505,27 @@ def from_unet2d( ms.load_param_into_net(model.time_proj, unet.time_proj.parameters_dict()) ms.load_param_into_net(model.time_embedding, unet.time_embedding.parameters_dict()) + if any(isinstance(proc, IPAdapterAttnProcessor) for proc in unet.attn_processors.values()): + attn_procs = {} + for name, processor in unet.attn_processors.items(): + if name.endswith("attn1.processor"): + attn_processor_class = AttnProcessor + attn_procs[name] = attn_processor_class() + else: + attn_processor_class = IPAdapterAttnProcessor + attn_procs[name] = attn_processor_class( + hidden_size=processor.hidden_size, + cross_attention_dim=processor.cross_attention_dim, + scale=processor.scale, + num_tokens=processor.num_tokens, + ) + for name, processor in model.attn_processors.items(): + if name not in attn_procs: + attn_procs[name] = processor.__class__() + model.set_attn_processor(attn_procs) + model.config.encoder_hid_dim_type = "ip_image_proj" + model.encoder_hid_proj = unet.encoder_hid_proj + for i, down_block in enumerate(unet.down_blocks): ms.load_param_into_net(model.down_blocks[i].resnets, down_block.resnets.parameters_dict()) if hasattr(model.down_blocks[i], "attentions"): @@ -639,10 +706,10 @@ def fn_recursive_feed_forward(module: nn.Cell, chunk_size: int, dim: int): if hasattr(module, "set_chunk_feed_forward"): module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) - for child in module.name_cells().items(): + for child in module.name_cells().values(): fn_recursive_feed_forward(child, chunk_size, dim) - for module in self.name_cells().items(): + for module in self.name_cells().values(): fn_recursive_feed_forward(module, chunk_size, dim) # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking @@ -651,10 +718,10 @@ def fn_recursive_feed_forward(module: nn.Cell, chunk_size: int, dim: int): if hasattr(module, "set_chunk_feed_forward"): module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) - for child in module.name_cells().items(): + for child in module.name_cells().values(): fn_recursive_feed_forward(child, chunk_size, dim) - for module in self.name_cells().items(): + for module in self.name_cells().values(): fn_recursive_feed_forward(module, None, 0) # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor @@ -713,13 +780,13 @@ def construct( mid_block_additional_residual: (`ms.Tensor`, *optional*): A tensor that if specified is added to the residual of the middle unet block. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain + Whether or not to return a [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] instead of a plain tuple. Returns: - [`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`: - If `return_dict` is True, an [`~models.unet_3d_condition.UNet3DConditionOutput`] is returned, otherwise - a `tuple` is returned where the first element is the sample tensor. + [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] is returned, + otherwise a `tuple` is returned where the first element is the sample tensor. """ # By default samples have to be AT least a multiple of the overall upsampling factor. # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). @@ -763,6 +830,28 @@ def construct( t_emb = t_emb.to(dtype=self.dtype) emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.config["addition_embed_type"] == "text_time": + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" # noqa: E501 + ) + + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" # noqa: E501 + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()).to(text_embeds.dtype) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = ops.concat([text_embeds, time_embeds], axis=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + + emb = emb if aug_emb is None else emb + aug_emb emb = emb.repeat_interleave(repeats=num_frames, dim=0) encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0) diff --git a/mindone/diffusers/models/unets/unet_spatio_temporal_condition.py b/mindone/diffusers/models/unets/unet_spatio_temporal_condition.py index be2a51f6cc..065316c088 100644 --- a/mindone/diffusers/models/unets/unet_spatio_temporal_condition.py +++ b/mindone/diffusers/models/unets/unet_spatio_temporal_condition.py @@ -31,8 +31,8 @@ class UNetSpatioTemporalConditionOutput(BaseOutput): class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): r""" - A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state, and a timestep and returns a sample - shaped output. + A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state, and a timestep and + returns a sample shaped output. This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented for all models (such as downloading or saving). @@ -59,8 +59,9 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL The dimension of the cross attention features. transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for - [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`], - [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`]. + [`~models.unets.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], + [`~models.unets.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`], + [`~models.unets.unet_3d_blocks.UNetMidBlockSpatioTemporal`]. num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`): The number of attention heads. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. @@ -355,10 +356,10 @@ def fn_recursive_feed_forward(module: nn.Cell, chunk_size: int, dim: int): if hasattr(module, "set_chunk_feed_forward"): module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) - for child in module.name_cells().items(): + for child in module.name_cells().values(): fn_recursive_feed_forward(child, chunk_size, dim) - for module in self.name_cells().items(): + for module in self.name_cells().values(): fn_recursive_feed_forward(module, chunk_size, dim) def construct( @@ -382,12 +383,12 @@ def construct( The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal embeddings and added to the time embeddings. return_dict (`bool`, *optional*, defaults to `False`): - Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain - tuple. + Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead + of a plain tuple. Returns: [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`: - If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise - a `tuple` is returned where the first element is the sample tensor. + If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is + returned, otherwise a `tuple` is returned where the first element is the sample tensor. """ # 1. time timesteps = timestep diff --git a/mindone/diffusers/models/unets/unet_stable_cascade.py b/mindone/diffusers/models/unets/unet_stable_cascade.py index 035007f36f..dbf4f11d76 100644 --- a/mindone/diffusers/models/unets/unet_stable_cascade.py +++ b/mindone/diffusers/models/unets/unet_stable_cascade.py @@ -23,8 +23,7 @@ from mindspore.common.initializer import Constant, Normal, XavierNormal, initializer from ...configuration_utils import ConfigMixin, register_to_config - -# from ...loaders.unet import FromOriginalUNetMixin +from ...loaders import FromOriginalModelMixin from ...utils import BaseOutput from ..attention_processor import Attention from ..modeling_utils import ModelMixin @@ -45,6 +44,7 @@ def construct(self, x): class SDCascadeTimestepBlock(nn.Cell): def __init__(self, c, c_timestep, conds=[]): super().__init__() + self.mapper = nn.Dense(c_timestep, c * 2) self.conds = conds for cname in conds: @@ -99,12 +99,11 @@ def construct(self, x): class SDCascadeAttnBlock(nn.Cell): def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): super().__init__() - linear_cls = nn.Dense self.self_attn = self_attn self.norm = SDCascadeLayerNorm(c, elementwise_affine=False, eps=1e-6) self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True) - self.kv_mapper = nn.SequentialCell(nn.SiLU(), linear_cls(c_cond, c)) + self.kv_mapper = nn.SequentialCell(nn.SiLU(), nn.Dense(c_cond, c)) def construct(self, x, kv): kv = self.kv_mapper(kv) @@ -140,8 +139,7 @@ class StableCascadeUNetOutput(BaseOutput): sample: ms.Tensor = None -# class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalUNetMixin): -class StableCascadeUNet(ModelMixin, ConfigMixin): +class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin): _supports_gradient_checkpointing = True @register_to_config @@ -193,7 +191,8 @@ def __init__( block_out_channels (Tuple[int], defaults to (2048, 2048)): Tuple of output channels for each block. num_attention_heads (Tuple[int], defaults to (32, 32)): - Number of attention heads in each attention block. Set to -1 to if block types in a layer do not have attention. + Number of attention heads in each attention block. Set to -1 to if block types in a layer do not have + attention. down_num_layers_per_block (Tuple[int], defaults to [8, 24]): Number of layers in each down block. up_num_layers_per_block (Tuple[int], defaults to [24, 8]): @@ -204,10 +203,9 @@ def __init__( Number of 1x1 Convolutional layers to repeat in each up block. block_types_per_layer (Tuple[Tuple[str]], optional, defaults to ( - ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"), - ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock") - ): - Block types used in each layer of the up/down blocks. + ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"), ("SDCascadeResBlock", + "SDCascadeTimestepBlock", "SDCascadeAttnBlock") + ): Block types used in each layer of the up/down blocks. clip_text_in_channels (`int`, *optional*, defaults to `None`): Number of input channels for CLIP based text conditioning. clip_text_pooled_in_channels (`int`, *optional*, defaults to 1280): @@ -562,7 +560,9 @@ def _up_decode(self, level_outputs, r_embed, clip): if isinstance(block, SDCascadeResBlock): skip = level_outputs[i] if k == 0 and i > 0 else None if skip is not None and (x.shape[-1] != skip.shape[-1] or x.shape[-2] != skip.shape[-2]): + orig_type = x.dtype x = ops.interpolate(x.float(), skip.shape[-2:], mode="bilinear", align_corners=True) + x = x.to(orig_type) x = block(x, skip) elif isinstance(block, SDCascadeAttnBlock): x = block(x, clip) diff --git a/mindone/diffusers/models/upsampling.py b/mindone/diffusers/models/upsampling.py index 271ed5cef8..626ee38ad0 100644 --- a/mindone/diffusers/models/upsampling.py +++ b/mindone/diffusers/models/upsampling.py @@ -110,7 +110,6 @@ def __init__( self.use_conv_transpose = use_conv_transpose self.name = name self.interpolate = interpolate - conv_cls = nn.Conv2d if norm_type == "ln_norm": self.norm = LayerNorm(channels, eps, elementwise_affine) @@ -137,7 +136,7 @@ def __init__( elif use_conv: if kernel_size is None: kernel_size = 3 - conv = conv_cls( + conv = nn.Conv2d( self.channels, self.out_channels, kernel_size=kernel_size, diff --git a/mindone/diffusers/models/vq_model.py b/mindone/diffusers/models/vq_model.py index acff1c73f6..f219db5319 100644 --- a/mindone/diffusers/models/vq_model.py +++ b/mindone/diffusers/models/vq_model.py @@ -11,164 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass -from typing import Optional, Tuple, Union +from ..utils import deprecate +from .autoencoders.vq_model import VQEncoderOutput, VQModel -import mindspore as ms -from mindspore import nn -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput -from .autoencoders.vae import Decoder, DecoderOutput, Encoder, VectorQuantizer -from .modeling_utils import ModelMixin +class VQEncoderOutput(VQEncoderOutput): + deprecation_message = "Importing `VQEncoderOutput` from `diffusers.models.vq_model` is deprecated and this will be removed in a future version. Please use `from diffusers.models.autoencoders.vq_model import VQEncoderOutput`, instead." # noqa: E501 + deprecate("VQEncoderOutput", "0.31", deprecation_message) -@dataclass -class VQEncoderOutput(BaseOutput): - """ - Output of VQModel encoding method. - - Args: - latents (`ms.Tensor` of shape `(batch_size, num_channels, height, width)`): - The encoded output sample from the last layer of the model. - """ - - latents: ms.Tensor - - -class VQModel(ModelMixin, ConfigMixin): - r""" - A VQ-VAE model for decoding latent representations. - - This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented - for all models (such as downloading or saving). - - Parameters: - in_channels (int, *optional*, defaults to 3): Number of channels in the input image. - out_channels (int, *optional*, defaults to 3): Number of channels in the output. - down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): - Tuple of downsample block types. - up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): - Tuple of upsample block types. - block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): - Tuple of block output channels. - layers_per_block (`int`, *optional*, defaults to `1`): Number of layers per block. - act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. - latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space. - sample_size (`int`, *optional*, defaults to `32`): Sample input size. - num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE. - norm_num_groups (`int`, *optional*, defaults to `32`): Number of groups for normalization layers. - vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE. - scaling_factor (`float`, *optional*, defaults to `0.18215`): - The component-wise standard deviation of the trained latent space computed using the first batch of the - training set. This is used to scale the latent space to have unit variance when training the diffusion - model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the - diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 - / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image - Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. - norm_type (`str`, *optional*, defaults to `"group"`): - Type of normalization layer to use. Can be one of `"group"` or `"spatial"`. - """ - - @register_to_config - def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",), - up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",), - block_out_channels: Tuple[int, ...] = (64,), - layers_per_block: int = 1, - act_fn: str = "silu", - latent_channels: int = 3, - sample_size: int = 32, - num_vq_embeddings: int = 256, - norm_num_groups: int = 32, - vq_embed_dim: Optional[int] = None, - scaling_factor: float = 0.18215, - norm_type: str = "group", # group, spatial - mid_block_add_attention=True, - lookup_from_codebook=False, - force_upcast=False, - ): - super().__init__() - - # pass init params to Encoder - self.encoder = Encoder( - in_channels=in_channels, - out_channels=latent_channels, - down_block_types=down_block_types, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - act_fn=act_fn, - norm_num_groups=norm_num_groups, - double_z=False, - mid_block_add_attention=mid_block_add_attention, - ) - - vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels - - self.quant_conv = nn.Conv2d(latent_channels, vq_embed_dim, 1, has_bias=True) - self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False) - self.post_quant_conv = nn.Conv2d(vq_embed_dim, latent_channels, 1, has_bias=True) - - # pass init params to Decoder - self.decoder = Decoder( - in_channels=latent_channels, - out_channels=out_channels, - up_block_types=up_block_types, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - act_fn=act_fn, - norm_num_groups=norm_num_groups, - norm_type=norm_type, - mid_block_add_attention=mid_block_add_attention, - ) - - def encode(self, x: ms.Tensor, return_dict: bool = False): - h = self.encoder(x) - h = self.quant_conv(h) - - if not return_dict: - return (h,) - - return VQEncoderOutput(latents=h) - - def decode(self, h: ms.Tensor, force_not_quantize: bool = False, return_dict: bool = False, shape=None): - # also go through quantization layer - if not force_not_quantize: - quant, _, _ = self.quantize(h) - elif self.config["lookup_from_codebook"]: - quant = self.quantize.get_codebook_entry(h, shape) - else: - quant = h - quant2 = self.post_quant_conv(quant) - dec = self.decoder(quant2, quant if self.config["norm_type"] == "spatial" else None) - - if not return_dict: - return (dec,) - - return DecoderOutput(sample=dec) - - def construct(self, sample: ms.Tensor, return_dict: bool = False) -> Union[DecoderOutput, Tuple[ms.Tensor, ...]]: - r""" - The [`VQModel`] forward method. - - Args: - sample (`ms.Tensor`): Input sample. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.vq_model.VQEncoderOutput`] instead of a plain tuple. - - Returns: - [`~models.vq_model.VQEncoderOutput`] or `tuple`: - If return_dict is True, a [`~models.vq_model.VQEncoderOutput`] is returned, otherwise a plain `tuple` - is returned. - """ - - h = self.encode(sample)[0] - dec = self.decode(h)[0] - - if not return_dict: - return (dec,) - - return DecoderOutput(sample=dec) +class VQModel(VQModel): + deprecation_message = "Importing `VQModel` from `diffusers.models.vq_model` is deprecated and this will be removed in a future version. Please use `from diffusers.models.autoencoders.vq_model import VQModel`, instead." # noqa: E501 + deprecate("VQModel", "0.31", deprecation_message) diff --git a/mindone/diffusers/pipelines/__init__.py b/mindone/diffusers/pipelines/__init__.py index 7be46685f9..0059056db4 100644 --- a/mindone/diffusers/pipelines/__init__.py +++ b/mindone/diffusers/pipelines/__init__.py @@ -6,6 +6,7 @@ _import_structure = { "animatediff": [ "AnimateDiffPipeline", + "AnimateDiffSDXLPipeline", "AnimateDiffVideoToVideoPipeline", ], "blip_diffusion": ["BlipDiffusionPipeline"], @@ -19,6 +20,13 @@ "StableDiffusionXLControlNetInpaintPipeline", "StableDiffusionXLControlNetPipeline", ], + "controlnet_xs": [ + "StableDiffusionControlNetXSPipeline", + "StableDiffusionXLControlNetXSPipeline", + ], + "controlnet_sd3": [ + "StableDiffusion3ControlNetPipeline", + ], "dance_diffusion": ["DanceDiffusionPipeline"], "ddim": ["DDIMPipeline"], "ddpm": ["DDPMPipeline"], @@ -31,6 +39,7 @@ "IFSuperResolutionPipeline", ], "dit": ["DiTPipeline"], + "hunyuandit": ["HunyuanDiTPipeline"], "i2vgen_xl": ["I2VGenXLPipeline"], "latent_diffusion": ["LDMSuperResolutionPipeline", "LDMTextToImagePipeline"], "kandinsky": [ @@ -62,7 +71,14 @@ "LatentConsistencyModelImg2ImgPipeline", "LatentConsistencyModelPipeline", ], - "pixart_alpha": ["PixArtAlphaPipeline"], + "marigold": [ + "MarigoldDepthPipeline", + "MarigoldNormalsPipeline", + ], + "pixart_alpha": [ + "PixArtAlphaPipeline", + "PixArtSigmaPipeline", + ], "shap_e": ["ShapEImg2ImgPipeline", "ShapEPipeline"], "stable_cascade": [ "StableCascadeCombinedPipeline", @@ -82,6 +98,7 @@ ], "stable_diffusion_3": [ "StableDiffusion3Pipeline", + "StableDiffusion3Img2ImgPipeline", ], "stable_diffusion_gligen": [ "StableDiffusionGLIGENPipeline", @@ -112,7 +129,7 @@ } if TYPE_CHECKING: - from .animatediff import AnimateDiffPipeline, AnimateDiffVideoToVideoPipeline + from .animatediff import AnimateDiffPipeline, AnimateDiffSDXLPipeline, AnimateDiffVideoToVideoPipeline from .blip_diffusion import BlipDiffusionPipeline from .consistency_models import ConsistencyModelPipeline from .controlnet import ( @@ -124,6 +141,8 @@ StableDiffusionXLControlNetInpaintPipeline, StableDiffusionXLControlNetPipeline, ) + from .controlnet_sd3 import StableDiffusion3ControlNetPipeline + from .controlnet_xs import StableDiffusionControlNetXSPipeline, StableDiffusionXLControlNetXSPipeline from .ddim import DDIMPipeline from .ddpm import DDPMPipeline from .deepfloyd_if import ( @@ -135,6 +154,7 @@ IFSuperResolutionPipeline, ) from .dit import DiTPipeline + from .hunyuandit import HunyuanDiTPipeline from .i2vgen_xl import I2VGenXLPipeline from .kandinsky import ( KandinskyCombinedPipeline, @@ -160,8 +180,9 @@ from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline from .latent_diffusion import LDMSuperResolutionPipeline, LDMTextToImagePipeline + from .marigold import MarigoldDepthPipeline, MarigoldNormalsPipeline from .pipeline_utils import DiffusionPipeline, ImagePipelineOutput - from .pixart_alpha import PixArtAlphaPipeline + from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline from .stable_cascade import StableCascadeCombinedPipeline, StableCascadeDecoderPipeline, StableCascadePriorPipeline from .stable_diffusion import ( @@ -175,7 +196,7 @@ StableDiffusionPipeline, StableDiffusionUpscalePipeline, ) - from .stable_diffusion_3 import StableDiffusion3Pipeline + from .stable_diffusion_3 import StableDiffusion3Img2ImgPipeline, StableDiffusion3Pipeline from .stable_diffusion_diffedit import StableDiffusionDiffEditPipeline from .stable_diffusion_gligen import StableDiffusionGLIGENPipeline, StableDiffusionGLIGENTextImagePipeline from .stable_diffusion_xl import ( diff --git a/mindone/diffusers/pipelines/animatediff/__init__.py b/mindone/diffusers/pipelines/animatediff/__init__.py index 282170a4ad..4bd77ba4ee 100644 --- a/mindone/diffusers/pipelines/animatediff/__init__.py +++ b/mindone/diffusers/pipelines/animatediff/__init__.py @@ -7,11 +7,13 @@ _import_structure["pipeline_animatediff"] = ["AnimateDiffPipeline"] +_import_structure["pipeline_animatediff_sdxl"] = ["AnimateDiffSDXLPipeline"] _import_structure["pipeline_animatediff_video2video"] = ["AnimateDiffVideoToVideoPipeline"] if TYPE_CHECKING: from .pipeline_animatediff import AnimateDiffPipeline + from .pipeline_animatediff_sdxl import AnimateDiffSDXLPipeline from .pipeline_animatediff_video2video import AnimateDiffVideoToVideoPipeline from .pipeline_output import AnimateDiffPipelineOutput else: diff --git a/mindone/diffusers/pipelines/animatediff/pipeline_animatediff.py b/mindone/diffusers/pipelines/animatediff/pipeline_animatediff.py index 532a487177..71213ed059 100644 --- a/mindone/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/mindone/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -22,7 +22,7 @@ from mindspore import ops from ....transformers import CLIPTextModel, CLIPVisionModelWithProjection -from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...image_processor import PipelineImageInput from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel from ...models.unets.unet_motion_model import MotionAdapter @@ -36,6 +36,7 @@ ) from ...utils import deprecate, logging, scale_lora_layers, unscale_lora_layers from ...utils.mindspore_utils import randn_tensor +from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import AnimateDiffPipelineOutput @@ -81,27 +82,6 @@ """ -def tensor2vid(video: ms.Tensor, processor: "VaeImageProcessor", output_type: str = "np"): - batch_size, channels, num_frames, height, width = video.shape - outputs = [] - for batch_idx in range(batch_size): - batch_vid = video[batch_idx].permute(1, 0, 2, 3) - batch_output = processor.postprocess(batch_vid, output_type) - - outputs.append(batch_output) - - if output_type == "np": - outputs = np.stack(outputs) - - elif output_type == "pt": - outputs = ops.stack(outputs) - - elif not output_type == "pil": - raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pil']") - - return outputs - - class AnimateDiffPipeline( DiffusionPipeline, TextualInversionLoaderMixin, @@ -145,7 +125,7 @@ def __init__( vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, + unet: Union[UNet2DConditionModel, UNetMotionModel], motion_adapter: MotionAdapter, scheduler: Union[ DDIMScheduler, @@ -173,7 +153,7 @@ def __init__( image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt # noqa: E501 def encode_prompt( @@ -341,9 +321,10 @@ def encode_prompt( negative_prompt_embeds = negative_prompt_embeds.tile((1, num_images_per_prompt, 1)) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - if isinstance(self, LoraLoaderMixin): - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder is not None: + if isinstance(self, LoraLoaderMixin): + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) return prompt_embeds, negative_prompt_embeds @@ -357,17 +338,15 @@ def encode_image(self, image, num_images_per_prompt, output_hidden_states=None): image = image.to(dtype=dtype) if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True)[2][-2] image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - ops.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] + uncond_image_enc_hidden_states = self.image_encoder(ops.zeros_like(image), output_hidden_states=True)[2][-2] uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( num_images_per_prompt, dim=0 ) return image_enc_hidden_states, uncond_image_enc_hidden_states else: - image_embeds = self.image_encoder(image).image_embeds + image_embeds = self.image_encoder(image)[0] image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = ops.zeros_like(image_embeds) @@ -636,11 +615,11 @@ def __call__( not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. - ip_adapter_image_embeds (`List[ms.Tensor]`, *optional*): - Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. - Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding - if `do_classifier_free_guidance` is set to `True`. - If not provided, embeddings are computed from the `ip_adapter_image` input argument. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated video. Choose between `ms.Tensor`, `PIL.Image` or `np.array`. @@ -782,7 +761,7 @@ def __call__( num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order # 8. Denoising loop - with self.progress_bar(total=num_inference_steps) as progress_bar: + with self.progress_bar(total=self._num_timesteps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = ops.cat([latents] * 2) if self.do_classifier_free_guidance else latents @@ -834,7 +813,7 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents) - video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) if not return_dict: return (video,) diff --git a/mindone/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/mindone/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py new file mode 100644 index 0000000000..90f0e04c93 --- /dev/null +++ b/mindone/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py @@ -0,0 +1,1206 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +from transformers import CLIPImageProcessor, CLIPTokenizer + +import mindspore as ms +from mindspore import ops + +from ....transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPVisionModelWithProjection +from ...image_processor import PipelineImageInput +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ImageProjection, MotionAdapter, UNet2DConditionModel, UNetMotionModel +from ...schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from ...utils import logging, scale_lora_layers, unscale_lora_layers +from ...utils.mindspore_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import AnimateDiffPipelineOutput + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import mindspore + >>> from mindone.diffusers.models import AutoencoderKL, MotionAdapter + >>> from mindone.diffusers import AnimateDiffSDXLPipeline, DDIMScheduler + >>> from mindone.diffusers.utils import export_to_gif + + >>> adapter = MotionAdapter.from_pretrained( + ... "a-r-r-o-w/animatediff-motion-adapter-sdxl-beta", mindspore_dtype=mindspore.float16 + ... ) + + >>> vae = AutoencoderKL.from_pretrained( + ... "madebyollin/sdxl-vae-fp16-fix", mindspore_dtype=mindspore.float16 + ... ) + + >>> model_id = "stabilityai/stable-diffusion-xl-base-1.0" + >>> scheduler = DDIMScheduler.from_pretrained( + ... model_id, + ... subfolder="scheduler", + ... clip_sample=False, + ... timestep_spacing="linspace", + ... beta_schedule="linear", + ... steps_offset=1, + ... ) + >>> pipe = AnimateDiffSDXLPipeline.from_pretrained( + ... model_id, + ... vae=vae, + ... motion_adapter=adapter, + ... scheduler=scheduler, + ... mindspore_dtype=mindspore.float16, + ... variant="fp16", + ... ) + + >>> output = pipe( + ... prompt="a panda surfing in the ocean, realistic, high quality", + ... negative_prompt="low quality, worst quality", + ... num_inference_steps=20, + ... guidance_scale=8, + ... width=1024, + ... height=1024, + ... num_frames=16, + ... ) + + >>> frames = output[0][0] + >>> export_to_gif(frames, "animation.gif") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(axis=tuple(range(1, noise_pred_text.ndim)), keepdims=True) + std_cfg = noise_cfg.std(axis=tuple(range(1, noise_cfg.ndim)), keepdims=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[ms.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and + the second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class AnimateDiffSDXLPipeline( + DiffusionPipeline, + FromSingleFileMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, +): + r""" + Pipeline for text-to-video generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "negative_add_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: Union[UNet2DConditionModel, UNetMotionModel], + motion_adapter: MotionAdapter, + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + force_zeros_for_empty_prompt: bool = True, + ): + super().__init__() + + if isinstance(unet, UNet2DConditionModel): + unet = UNetMotionModel.from_unet2d(unet, motion_adapter) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + motion_adapter=motion_adapter, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) + + self.default_sample_size = self.unet.config.sample_size + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt with num_images_per_prompt->num_videos_per_prompt # noqa E501 + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + num_videos_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[ms.Tensor] = None, + negative_prompt_embeds: Optional[ms.Tensor] = None, + pooled_prompt_embeds: Optional[ms.Tensor] = None, + negative_pooled_prompt_embeds: Optional[ms.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + num_videos_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + + text_input_ids = ms.Tensor.from_numpy(text_inputs.input_ids) + untruncated_ids = ms.Tensor.from_numpy( + tokenizer(prompt, padding="longest", return_tensors="np").input_ids + ) + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not ops.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids, output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds[2][-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds[2][-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = ops.concat(prompt_embeds_list, axis=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = ops.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = ops.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="np", + ) + + negative_prompt_embeds = text_encoder( + ms.Tensor.from_numpy(uncond_input.input_ids), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds[2][-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = ops.concat(negative_prompt_embeds_list, axis=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.tile((1, num_videos_per_prompt, 1)) + prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype) + + negative_prompt_embeds = negative_prompt_embeds.tile((1, num_videos_per_prompt, 1)) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.tile((1, num_videos_per_prompt)).view( + bs_embed * num_videos_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.tile((1, num_videos_per_prompt)).view( + bs_embed * num_videos_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin): + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin): + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.get_parameters()).dtype + + if not isinstance(image, ms.Tensor): + image = self.feature_extractor(image, return_tensors="np").pixel_values + image = ms.Tensor(image) + + image = image.to(dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True)[2][-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder(ops.zeros_like(image), output_hidden_states=True)[2][-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image)[0] + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = ops.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, num_images_per_prompt, do_classifier_free_guidance + ): + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." # noqa: E501 + ) + + image_embeds = [] + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, 1, output_hidden_state + ) + single_image_embeds = ops.stack([single_image_embeds] * num_images_per_prompt, axis=0) + single_negative_image_embeds = ops.stack([single_negative_image_embeds] * num_images_per_prompt, axis=0) + + if do_classifier_free_guidance: + single_image_embeds = ops.cat([single_negative_image_embeds, single_image_embeds]) + + image_embeds.append(single_image_embeds) + else: + repeat_dims = [1] + image_embeds = [] + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + single_image_embeds = single_image_embeds.tile( + (num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))) + ) + single_negative_image_embeds = single_negative_image_embeds.tile( + (num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))) + ) + single_image_embeds = ops.cat([single_negative_image_embeds, single_image_embeds]) + else: + single_image_embeds = single_image_embeds.tile( + (num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))) + ) + image_embeds.append(single_image_embeds) + + return image_embeds + + # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + image = self.vae.decode(latents)[0] + video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" # noqa E501 + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." # noqa E501 + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." # noqa E501 + ) + + # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, generator, latents=None + ): + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_channels + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." # noqa E501 + ) + + add_time_ids = ms.Tensor([add_time_ids], dtype=dtype) + return add_time_ids + + def upcast_vae(self): + self.vae.to(dtype=ms.float32) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: ms.Tensor, embedding_dim: int = 512, dtype: ms.Type = ms.float32 + ) -> ms.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`ms.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`mindspore.dtype`, *optional*, defaults to `mindspore.float32`): + Data type of the generated embeddings. + + Returns: + `ms.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = ops.log(ms.Tensor(10000.0)) / (half_dim - 1) + emb = ops.exp(ops.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = ops.cat([ops.sin(emb), ops.cos(emb)], axis=1) + if embedding_dim % 2 == 1: # zero pad + emb = ms.mint.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + num_frames: int = 16, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[np.random.Generator, List[np.random.Generator]]] = None, + latents: Optional[ms.Tensor] = None, + prompt_embeds: Optional[ms.Tensor] = None, + negative_prompt_embeds: Optional[ms.Tensor] = None, + pooled_prompt_embeds: Optional[ms.Tensor] = None, + negative_pooled_prompt_embeds: Optional[ms.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[ms.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = False, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + num_frames: + The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds + amounts to 2 seconds of video. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated video. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated video. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality video at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower video quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the video generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the video generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`np.random.Generator` or `List[np.random.Generator]`, *optional*): + One or a list of [numpy generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`ms.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[ms.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. If not provided, embeddings are computed from the + `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.AnimateDiffPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. + """ + + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + num_videos_per_prompt = 1 + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + lora_scale = self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + num_videos_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps, sigmas) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + num_frames, + height, + width, + prompt_embeds.dtype, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_classifier_free_guidance: + prompt_embeds = ops.cat([negative_prompt_embeds, prompt_embeds], axis=0) + add_text_embeds = ops.cat([negative_pooled_prompt_embeds, add_text_embeds], axis=0) + add_time_ids = ops.cat([negative_add_time_ids, add_time_ids], axis=0) + + add_time_ids = add_time_ids.tile((batch_size * num_videos_per_prompt, 1)) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + batch_size * num_videos_per_prompt, + self.do_classifier_free_guidance, + ) + + # 7.1 Apply denoising_end + if ( + self.denoising_end is not None + and isinstance(self.denoising_end, float) + and self.denoising_end > 0 + and self.denoising_end < 1 + ): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 8. Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = ms.Tensor(self.guidance_scale - 1).tile((batch_size * num_videos_per_prompt)) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(dtype=latents.dtype) + + self._num_timesteps = len(timesteps) + + # 9. Denoising loop + with self.progress_bar(total=self._num_timesteps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = ops.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image is not None or ip_adapter_image_embeds: + added_cond_kwargs["image_embeds"] = image_embeds + + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=ms.mutable(added_cond_kwargs), + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) + + progress_bar.update() + + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == ms.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.get_parameters())).dtype) + + # 10. Post processing + if output_type == "latent": + video = latents + else: + video_tensor = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=ms.float16) + + if not return_dict: + return (video,) + + return AnimateDiffPipelineOutput(frames=video) diff --git a/mindone/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/mindone/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index 8c10da1742..d79c53659b 100644 --- a/mindone/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/mindone/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -22,7 +22,7 @@ from mindspore import ops from ....transformers import CLIPTextModel, CLIPVisionModelWithProjection -from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...image_processor import PipelineImageInput from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel from ...models.unets.unet_motion_model import MotionAdapter @@ -36,6 +36,7 @@ ) from ...utils import logging, scale_lora_layers, unscale_lora_layers from ...utils.mindspore_utils import randn_tensor +from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import AnimateDiffPipelineOutput @@ -52,14 +53,21 @@ >>> from io import BytesIO >>> from PIL import Image - >>> adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", mindspore_dtype=mindspore.float16) - >>> pipe = AnimateDiffVideoToVideoPipeline.from_pretrained("SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter) - >>> pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False) + >>> adapter = MotionAdapter.from_pretrained( + ... "guoyww/animatediff-motion-adapter-v1-5-2", mindspore_dtype=mindspore.float16 + ... ) + >>> pipe = AnimateDiffVideoToVideoPipeline.from_pretrained( + ... "SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter + ... ) + >>> pipe.scheduler = DDIMScheduler( + ... beta_schedule="linear", steps_offset=1, clip_sample=False, timespace_spacing="linspace" + ... ) + >>> def load_video(file_path: str): ... images = [] - ... - ... if file_path.startswith(('http://', 'https://')): + + ... if file_path.startswith(("http://", "https://")): ... # If the file_path is a URL ... response = requests.get(file_path) ... response.raise_for_status() @@ -68,43 +76,26 @@ ... else: ... # Assuming it's a local file path ... vid = imageio.get_reader(file_path) - ... + ... for frame in vid: ... pil_image = Image.fromarray(frame) ... images.append(pil_image) - ... + ... return images - >>> video = load_video("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif") - >>> output = pipe(video=video, prompt="panda playing a guitar, on a boat, in the ocean, high quality", strength=0.5) - >>> frames = output[0][0] + + >>> video = load_video( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif" + ... ) + >>> output = pipe( + ... video=video, prompt="panda playing a guitar, on a boat, in the ocean, high quality", strength=0.5 + ... ) + >>> frames = output.frames[0] >>> export_to_gif(frames, "animation.gif") ``` """ -# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid -def tensor2vid(video: ms.Tensor, processor: "VaeImageProcessor", output_type: str = "np"): - batch_size, channels, num_frames, height, width = video.shape - outputs = [] - for batch_idx in range(batch_size): - batch_vid = video[batch_idx].permute(1, 0, 2, 3) - batch_output = processor.postprocess(batch_vid, output_type) - - outputs.append(batch_output) - - if output_type == "np": - outputs = np.stack(outputs) - - elif output_type == "pt": - outputs = ops.stack(outputs) - - elif not output_type == "pil": - raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']") - - return outputs - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( vae, encoder_output: ms.Tensor, generator: Optional[np.random.Generator] = None, sample_mode: str = "sample" @@ -125,6 +116,7 @@ def retrieve_timesteps( scheduler, num_inference_steps: Optional[int] = None, timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, **kwargs, ): """ @@ -135,17 +127,21 @@ def retrieve_timesteps( scheduler (`SchedulerMixin`): The scheduler to get timesteps from. num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, - `timesteps` must be `None`. + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default - timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` - must be `None`. + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[ms.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: @@ -156,6 +152,16 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, **kwargs) timesteps = scheduler.timesteps @@ -233,7 +239,7 @@ def __init__( image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt # noqa: E501 def encode_prompt( @@ -401,9 +407,10 @@ def encode_prompt( negative_prompt_embeds = negative_prompt_embeds.tile((1, num_images_per_prompt, 1)) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - if isinstance(self, LoraLoaderMixin): - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder is not None: + if isinstance(self, LoraLoaderMixin): + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) return prompt_embeds, negative_prompt_embeds @@ -417,17 +424,15 @@ def encode_image(self, image, num_images_per_prompt, output_hidden_states=None): image = image.to(dtype=dtype) if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True)[2][-2] image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - ops.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] + uncond_image_enc_hidden_states = self.image_encoder(ops.zeros_like(image), output_hidden_states=True)[2][-2] uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( num_images_per_prompt, dim=0 ) return image_enc_hidden_states, uncond_image_enc_hidden_states else: - image_embeds = self.image_encoder(image).image_embeds + image_embeds = self.image_encoder(image)[0] image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = ops.zeros_like(image_embeds) @@ -606,16 +611,7 @@ def prepare_latents( generator, latents=None, ): - # video must be a list of list of images - # the outer list denotes having multiple videos as input, whereas inner list means the frames of the video - # as a list of images - if not isinstance(video[0], list): - video = [video] if latents is None: - video = ops.cat( - [self.image_processor.preprocess(vid, height=height, width=width).unsqueeze(0) for vid in video], axis=0 - ) - video = video.to(dtype=dtype) num_frames = video.shape[1] else: num_frames = latents.shape[2] @@ -720,6 +716,7 @@ def __call__( width: Optional[int] = None, num_inference_steps: int = 50, timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, guidance_scale: float = 7.5, strength: float = 0.8, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -753,6 +750,14 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality videos at the expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. strength (`float`, *optional*, defaults to 0.8): Higher strength leads to more differences between original video and generated video. guidance_scale (`float`, *optional*, defaults to 7.5): @@ -780,17 +785,16 @@ def __call__( not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. - ip_adapter_image_embeds (`List[ms.Tensor]`, *optional*): - Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. - Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding - if `do_classifier_free_guidance` is set to `True`. - If not provided, embeddings are computed from the `ip_adapter_image` input argument. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated video. Choose between `ms.Tensor`, `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `False`): - Whether or not to return a [`AnimateDiffPipelineOutput`] instead - of a plain tuple. + Whether or not to return a [`AnimateDiffPipelineOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). @@ -805,7 +809,7 @@ def __call__( callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the - `._callback_tensor_inputs` attribute of your pipeine class. + `._callback_tensor_inputs` attribute of your pipeline class. Examples: @@ -879,11 +883,16 @@ def __call__( ) # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps) + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps, sigmas) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength) latent_timestep = timesteps[:1].tile((batch_size * num_videos_per_prompt,)) # 5. Prepare latent variables + if latents is None: + video = self.video_processor.preprocess_video(video, height=height, width=width) + # Move the number of frames before the number of channels. + video = video.permute(0, 2, 1, 3, 4) + video = video.to(prompt_embeds.dtype) num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( video=video, @@ -915,7 +924,7 @@ def __call__( num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order # 8. Denoising loop - with self.progress_bar(total=num_inference_steps) as progress_bar: + with self.progress_bar(total=self._num_timesteps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = ops.cat([latents] * 2) if self.do_classifier_free_guidance else latents @@ -965,7 +974,7 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents) - video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) if not return_dict: return (video,) diff --git a/mindone/diffusers/pipelines/animatediff/pipeline_output.py b/mindone/diffusers/pipelines/animatediff/pipeline_output.py index f9e58f0f97..ef1b55dbd2 100644 --- a/mindone/diffusers/pipelines/animatediff/pipeline_output.py +++ b/mindone/diffusers/pipelines/animatediff/pipeline_output.py @@ -15,8 +15,8 @@ class AnimateDiffPipelineOutput(BaseOutput): Output class for AnimateDiff pipelines. Args: - frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): - List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing denoised + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape `(batch_size, num_frames, channels, height, width)` """ diff --git a/mindone/diffusers/pipelines/controlnet/multicontrolnet.py b/mindone/diffusers/pipelines/controlnet/multicontrolnet.py index ae7fb0be06..7733b8ad22 100644 --- a/mindone/diffusers/pipelines/controlnet/multicontrolnet.py +++ b/mindone/diffusers/pipelines/controlnet/multicontrolnet.py @@ -99,20 +99,16 @@ def save_pretrained( variant (`str`, *optional*): If specified, weights are saved in the format pytorch_model..bin. """ - idx = 0 - model_path_to_save = save_directory - for controlnet in self.nets: + for idx, controlnet in enumerate(self.nets): + suffix = "" if idx == 0 else f"_{idx}" controlnet.save_pretrained( - model_path_to_save, + save_directory + suffix, is_main_process=is_main_process, save_function=save_function, safe_serialization=safe_serialization, variant=variant, ) - idx += 1 - model_path_to_save = model_path_to_save + f"_{idx}" - @classmethod def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs): r""" diff --git a/mindone/diffusers/pipelines/controlnet/pipeline_controlnet.py b/mindone/diffusers/pipelines/controlnet/pipeline_controlnet.py index 8bfd0d6675..791fc07b9a 100644 --- a/mindone/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/mindone/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -24,6 +24,7 @@ from mindspore import ops from ....transformers import CLIPTextModel, CLIPVisionModelWithProjection +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel @@ -84,6 +85,7 @@ def retrieve_timesteps( scheduler, num_inference_steps: Optional[int] = None, timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, **kwargs, ): """ @@ -94,17 +96,21 @@ def retrieve_timesteps( scheduler (`SchedulerMixin`): The scheduler to get timesteps from. num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, - `timesteps` must be `None`. + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default - timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` - must be `None`. + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[ms.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: @@ -115,6 +121,16 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, **kwargs) timesteps = scheduler.timesteps @@ -419,9 +435,10 @@ def encode_prompt( negative_prompt_embeds = negative_prompt_embeds.tile((1, num_images_per_prompt, 1)) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - if isinstance(self, LoraLoaderMixin): - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder is not None: + if isinstance(self, LoraLoaderMixin): + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) return prompt_embeds, negative_prompt_embeds @@ -435,17 +452,15 @@ def encode_image(self, image, num_images_per_prompt, output_hidden_states=None): image = image.to(dtype=dtype) if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True)[2][-2] image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - ops.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] + uncond_image_enc_hidden_states = self.image_encoder(ops.zeros_like(image), output_hidden_states=True)[2][-2] uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( num_images_per_prompt, dim=0 ) return image_enc_hidden_states, uncond_image_enc_hidden_states else: - image_embeds = self.image_encoder(image).image_embeds + image_embeds = self.image_encoder(image)[0] image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = ops.zeros_like(image_embeds) @@ -629,9 +644,9 @@ def check_inputs( raise ValueError( f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." # noqa: E501 ) - - for image_ in image: - self.check_image(image_, prompt, prompt_embeds) + else: + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) else: assert False @@ -766,7 +781,12 @@ def prepare_image( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -785,20 +805,22 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype return latents # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding - def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=ms.float32): + def get_guidance_scale_embedding( + self, w: ms.Tensor, embedding_dim: int = 512, dtype: ms.Type = ms.float32 + ) -> ms.Tensor: """ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 Args: - w (`ms.Tensor`): + w (`torch.Tensor`): Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. embedding_dim (`int`, *optional*, defaults to 512): Dimension of the embeddings to generate. - dtype: + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): Data type of the generated embeddings. Returns: - `ms.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. """ assert len(w.shape) == 1 w = w * 1000.0 @@ -844,6 +866,7 @@ def __call__( width: Optional[int] = None, num_inference_steps: int = 50, timesteps: List[int] = None, + sigmas: List[float] = None, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, @@ -862,7 +885,9 @@ def __call__( control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): @@ -875,13 +900,13 @@ def __call__( image (`ms.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[ms.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: `List[List[ms.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): The ControlNet input condition to provide guidance to the `unet` for generation. If the type is - specified as `ms.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be - accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height - and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in - `init`, images must be passed as a list such that each element of the list can be correctly batched for - input to a single ControlNet. When `prompt` is a list, and if a list of images is passed for a single ControlNet, - each will be paired with each prompt in the `prompt` list. This also applies to multiple ControlNets, - where a list of image lists can be passed to batch for each prompt and each ControlNet. + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. When `prompt` is a list, and if a list of images is passed for a single + ControlNet, each will be paired with each prompt in the `prompt` list. This also applies to multiple + ControlNets, where a list of image lists can be passed to batch for each prompt and each ControlNet. height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The height in pixels of the generated image. width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): @@ -893,6 +918,10 @@ def __call__( Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. @@ -919,10 +948,10 @@ def __call__( not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. ip_adapter_image_embeds (`List[ms.Tensor]`, *optional*): - Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. - Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding - if `do_classifier_free_guidance` is set to `True`. - If not provided, embeddings are computed from the `ip_adapter_image` input argument. + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `False`): @@ -952,14 +981,14 @@ def __call__( Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the - `._callback_tensor_inputs` attribute of your pipeine class. + `._callback_tensor_inputs` attribute of your pipeline class. Examples: @@ -987,6 +1016,9 @@ def __call__( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + controlnet = self.controlnet # align format for control guidance @@ -1108,7 +1140,7 @@ def __call__( assert False # 5. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps) + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps, sigmas) self._num_timesteps = len(timesteps) # 6. Prepare latent variables diff --git a/mindone/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/mindone/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py index c9b61a206d..9939706f4b 100644 --- a/mindone/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +++ b/mindone/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -23,6 +23,7 @@ from mindspore import ops from ....transformers import CLIPTextModel, CLIPVisionModelWithProjection +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel @@ -418,9 +419,10 @@ def encode_prompt( negative_prompt_embeds = negative_prompt_embeds.tile((1, num_images_per_prompt, 1)) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - if isinstance(self, LoraLoaderMixin): - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder is not None: + if isinstance(self, LoraLoaderMixin): + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) return prompt_embeds, negative_prompt_embeds @@ -434,17 +436,15 @@ def encode_image(self, image, num_images_per_prompt, output_hidden_states=None): image = image.to(dtype=dtype) if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True)[2][-2] image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - ops.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] + uncond_image_enc_hidden_states = self.image_encoder(ops.zeros_like(image), output_hidden_states=True)[2][-2] uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( num_images_per_prompt, dim=0 ) return image_enc_hidden_states, uncond_image_enc_hidden_states else: - image_embeds = self.image_encoder(image).image_embeds + image_embeds = self.image_encoder(image)[0] image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = ops.zeros_like(image_embeds) @@ -878,7 +878,9 @@ def __call__( control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): @@ -895,11 +897,11 @@ def __call__( control_image (`ms.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[ms.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: `List[List[ms.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): The ControlNet input condition to provide guidance to the `unet` for generation. If the type is - specified as `ms.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be - accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height - and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in - `init`, images must be passed as a list such that each element of the list can be correctly batched for - input to a single ControlNet. + specified as `ms.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The height in pixels of the generated image. width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): @@ -939,10 +941,10 @@ def __call__( not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. ip_adapter_image_embeds (`List[ms.Tensor]`, *optional*): - Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. - Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding - if `do_classifier_free_guidance` is set to `True`. - If not provided, embeddings are computed from the `ip_adapter_image` input argument. + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `False`): @@ -965,15 +967,15 @@ def __call__( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the - `._callback_tensor_inputs` attribute of your pipeine class. + `._callback_tensor_inputs` attribute of your pipeline class. Examples: @@ -1001,6 +1003,9 @@ def __call__( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + controlnet = self.controlnet # align format for control guidance @@ -1124,14 +1129,15 @@ def __call__( self._num_timesteps = len(timesteps) # 6. Prepare latent variables - latents = self.prepare_latents( - image, - latent_timestep, - batch_size, - num_images_per_prompt, - prompt_embeds.dtype, - generator, - ) + if latents is None: + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + generator, + ) # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) diff --git a/mindone/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/mindone/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index b622e1b643..927d2bc9dc 100644 --- a/mindone/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/mindone/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -25,6 +25,7 @@ from mindspore import ops from ....transformers import CLIPTextModel, CLIPVisionModelWithProjection +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel @@ -546,9 +547,10 @@ def encode_prompt( negative_prompt_embeds = negative_prompt_embeds.tile((1, num_images_per_prompt, 1)) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - if isinstance(self, LoraLoaderMixin): - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder is not None: + if isinstance(self, LoraLoaderMixin): + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) return prompt_embeds, negative_prompt_embeds @@ -562,17 +564,15 @@ def encode_image(self, image, num_images_per_prompt, output_hidden_states=None): image = image.to(dtype=dtype) if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True)[2][-2] image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - ops.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] + uncond_image_enc_hidden_states = self.image_encoder(ops.zeros_like(image), output_hidden_states=True)[2][-2] uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( num_images_per_prompt, dim=0 ) return image_enc_hidden_states, uncond_image_enc_hidden_states else: - image_embeds = self.image_encoder(image).image_embeds + image_embeds = self.image_encoder(image)[0] image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = ops.zeros_like(image_embeds) @@ -937,7 +937,12 @@ def prepare_latents( return_noise=False, return_image_latents=False, ): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -1090,7 +1095,9 @@ def __call__( control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): @@ -1118,21 +1125,22 @@ def __call__( control_image (`ms.Tensor`, `PIL.Image.Image`, `List[ms.Tensor]`, `List[PIL.Image.Image]`, `List[List[ms.Tensor]]`, or `List[List[PIL.Image.Image]]`): The ControlNet input condition to provide guidance to the `unet` for generation. If the type is - specified as `ms.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be - accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height - and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in - `init`, images must be passed as a list such that each element of the list can be correctly batched for - input to a single ControlNet. + specified as `ms.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The height in pixels of the generated image. width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The width in pixels of the generated image. padding_mask_crop (`int`, *optional*, defaults to `None`): - The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to image and mask_image. If - `padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ration of the image and - contains all masked area, and then expand that area based on `padding_mask_crop`. The image and mask_image will then be cropped based on - the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large # noqa: E501 - and contain information inreleant for inpainging, such as background. + The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to + image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region + with the same aspect ration of the image and contains all masked area, and then expand that area based + on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before + resizing to the original image size for inpainting. This is useful when the masked area is small while + the image is large and contain information irrelevant for inpainting, such as background. strength (`float`, *optional*, defaults to 1.0): Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a starting point and more noise is added the higher the `strength`. The number of denoising steps depends @@ -1168,10 +1176,10 @@ def __call__( not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. ip_adapter_image_embeds (`List[ms.Tensor]`, *optional*): - Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. - Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding - if `do_classifier_free_guidance` is set to `True`. - If not provided, embeddings are computed from the `ip_adapter_image` input argument. + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `False`): @@ -1194,15 +1202,15 @@ def __call__( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the - `._callback_tensor_inputs` attribute of your pipeine class. + `._callback_tensor_inputs` attribute of your pipeline class. Examples: @@ -1230,6 +1238,9 @@ def __call__( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + controlnet = self.controlnet # align format for control guidance diff --git a/mindone/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/mindone/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py index 0887d90833..3341934678 100644 --- a/mindone/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +++ b/mindone/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -23,6 +23,7 @@ from mindspore import ops from ....transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPVisionModelWithProjection +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import ( FromSingleFileMixin, @@ -117,8 +118,8 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 """ - std_text = noise_pred_text.std(axis=list(range(1, noise_pred_text.ndim)), keepdims=True) - std_cfg = noise_cfg.std(axis=list(range(1, noise_cfg.ndim)), keepdims=True) + std_text = noise_pred_text.std(axis=tuple(range(1, noise_pred_text.ndim)), keepdims=True) + std_cfg = noise_cfg.std(axis=tuple(range(1, noise_cfg.ndim)), keepdims=True) # rescale the results from guidance (fixes overexposure) noise_pred_rescaled = noise_cfg * (std_text / std_cfg) # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images @@ -127,7 +128,11 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): class StableDiffusionXLControlNetInpaintPipeline( - DiffusionPipeline, StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin, IPAdapterMixin + DiffusionPipeline, + StableDiffusionXLLoraLoaderMixin, + FromSingleFileMixin, + IPAdapterMixin, + TextualInversionLoaderMixin, ): r""" Pipeline for text-to-image generation using Stable Diffusion XL. @@ -136,6 +141,7 @@ class StableDiffusionXLControlNetInpaintPipeline( library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files @@ -167,8 +173,26 @@ class StableDiffusionXLControlNetInpaintPipeline( """ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" - _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] - _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "add_neg_time_ids", + "mask", + "masked_image_latents", + ] def __init__( self, @@ -178,7 +202,7 @@ def __init__( tokenizer: CLIPTokenizer, tokenizer_2: CLIPTokenizer, unet: UNet2DConditionModel, - controlnet: ControlNetModel, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], scheduler: KarrasDiffusionSchedulers, requires_aesthetics_score: bool = False, force_zeros_for_empty_prompt: bool = True, @@ -450,17 +474,15 @@ def encode_image(self, image, num_images_per_prompt, output_hidden_states=None): image = image.to(dtype=dtype) if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True)[2][-2] image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - ops.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] + uncond_image_enc_hidden_states = self.image_encoder(ops.zeros_like(image), output_hidden_states=True)[2][-2] uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( num_images_per_prompt, dim=0 ) return image_enc_hidden_states, uncond_image_enc_hidden_states else: - image_embeds = self.image_encoder(image).image_embeds + image_embeds = self.image_encoder(image)[0] image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = ops.zeros_like(image_embeds) @@ -817,7 +839,12 @@ def prepare_latents( return_noise=False, return_image_latents=False, ): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -954,7 +981,7 @@ def get_timesteps(self, num_inference_steps, strength, denoising_start=None): # because `num_inference_steps` might be even given that every timestep # (except the highest one) is duplicated. If `num_inference_steps` is even it would # mean that we cut the timesteps in the middle of the denoising step - # (between 1st and 2nd devirative) which leads to incorrect results. By adding 1 + # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1 # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler num_inference_steps = num_inference_steps + 1 @@ -1081,7 +1108,9 @@ def __call__( aesthetic_score: float = 6.0, negative_aesthetic_score: float = 2.5, clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): @@ -1108,11 +1137,12 @@ def __call__( width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. padding_mask_crop (`int`, *optional*, defaults to `None`): - The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to image and mask_image. If - `padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ration of the image and - contains all masked area, and then expand that area based on `padding_mask_crop`. The image and mask_image will then be cropped based on - the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large # noqa: E501 - and contain information inreleant for inpainging, such as background. + The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to + image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region + with the same aspect ration of the image and contains all masked area, and then expand that area based + on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before + resizing to the original image size for inpainting. This is useful when the masked area is small while + the image is large and contain information irrelevant for inpainting, such as background. strength (`float`, *optional*, defaults to 0.9999): Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the @@ -1161,10 +1191,10 @@ def __call__( argument. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. ip_adapter_image_embeds (`List[ms.Tensor]`, *optional*): - Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. - Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding - if `do_classifier_free_guidance` is set to `True`. - If not provided, embeddings are computed from the `ip_adapter_image` input argument. + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. pooled_prompt_embeds (`ms.Tensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. @@ -1219,11 +1249,11 @@ def __call__( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the @@ -1253,6 +1283,9 @@ def __call__( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + controlnet = self.controlnet # align format for control guidance @@ -1506,10 +1539,7 @@ def denoising_value_valid(dnv): 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) for s, e in zip(control_guidance_start, control_guidance_end) ] - if isinstance(self.controlnet, MultiControlNetModel): - controlnet_keep.append(keeps) - else: - controlnet_keep.append(keeps[0]) + controlnet_keep.append(keeps if isinstance(controlnet, MultiControlNetModel) else keeps[0]) # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline height, width = latents.shape[-2:] @@ -1633,7 +1663,7 @@ def denoising_value_valid(dnv): down_block_res_samples = [ops.cat([ops.zeros_like(d), d]) for d in down_block_res_samples] mid_block_res_sample = ops.cat([ops.zeros_like(mid_block_res_sample), mid_block_res_sample]) - if ip_adapter_image is not None: + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: added_cond_kwargs["image_embeds"] = image_embeds if num_channels_unet == 9: diff --git a/mindone/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/mindone/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index d9dd545f88..6cd2dbaa1c 100644 --- a/mindone/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/mindone/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -24,6 +24,7 @@ from mindspore import ops from ....transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPVisionModelWithProjection +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import ( FromSingleFileMixin, @@ -87,6 +88,63 @@ """ +# Copied from mindone.diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[ms.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and + the second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + class StableDiffusionXLControlNetPipeline( DiffusionPipeline, TextualInversionLoaderMixin, @@ -147,7 +205,15 @@ class StableDiffusionXLControlNetPipeline( "feature_extractor", "image_encoder", ] - _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "negative_add_time_ids", + ] def __init__( self, @@ -416,19 +482,17 @@ def encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, num_images_per_prompt, output_hidden_states=None): - dtype = self.image_encoder.dtype + dtype = next(self.image_encoder.get_parameters()).dtype if not isinstance(image, ms.Tensor): image = self.feature_extractor(image, return_tensors="np").pixel_values - image = ms.Tensor.from_numpy(image) + image = ms.Tensor(image) image = image.to(dtype=dtype) if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True)[-1][-2] + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True)[2][-2] image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder(ops.zeros_like(image), output_hidden_states=True)[-1][ - -2 - ] + uncond_image_enc_hidden_states = self.image_encoder(ops.zeros_like(image), output_hidden_states=True)[2][-2] uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( num_images_per_prompt, dim=0 ) @@ -750,7 +814,12 @@ def prepare_image( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -789,20 +858,22 @@ def upcast_vae(self): self.vae.to(dtype=ms.float32) # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding - def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=ms.float32): + def get_guidance_scale_embedding( + self, w: ms.Tensor, embedding_dim: int = 512, dtype: ms.Type = ms.float32 + ) -> ms.Tensor: """ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 Args: - w (`ms.Tensor`): + w (`torch.Tensor`): Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. embedding_dim (`int`, *optional*, defaults to 512): Dimension of the embeddings to generate. - dtype: + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): Data type of the generated embeddings. Returns: - `ms.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. """ assert len(w.shape) == 1 w = w * 1000.0 @@ -852,6 +923,8 @@ def __call__( height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, denoising_end: Optional[float] = None, guidance_scale: float = 5.0, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -880,7 +953,9 @@ def __call__( negative_crops_coords_top_left: Tuple[int, int] = (0, 0), negative_target_size: Optional[Tuple[int, int]] = None, clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): @@ -896,11 +971,11 @@ def __call__( image (`ms.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[ms.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: `List[List[ms.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): The ControlNet input condition to provide guidance to the `unet` for generation. If the type is - specified as `ms.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be - accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height - and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in - `init`, images must be passed as a list such that each element of the list can be correctly batched for - input to a single ControlNet. + specified as `mindspore.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The height in pixels of the generated image. Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) @@ -912,6 +987,14 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. denoising_end (`float`, *optional*): When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be completed before it is intentionally prematurely terminated. As a result, the returned sample will @@ -955,10 +1038,10 @@ def __call__( argument. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. ip_adapter_image_embeds (`List[ms.Tensor]`, *optional*): - Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. - Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding - if `do_classifier_free_guidance` is set to `True`. - If not provided, embeddings are computed from the `ip_adapter_image` input argument. + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `False`): @@ -1010,15 +1093,15 @@ def __call__( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the - `._callback_tensor_inputs` attribute of your pipeine class. + `._callback_tensor_inputs` attribute of your pipeline class. Examples: @@ -1044,6 +1127,9 @@ def __call__( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + controlnet = self.controlnet # align format for control guidance @@ -1170,8 +1256,7 @@ def __call__( assert False # 5. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps) - timesteps = self.scheduler.timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps, sigmas) self._num_timesteps = len(timesteps) # 6. Prepare latent variables @@ -1364,6 +1449,12 @@ def __call__( latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): diff --git a/mindone/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/mindone/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index 7945c84ca4..b43c0c4551 100644 --- a/mindone/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/mindone/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -24,8 +24,14 @@ from mindspore import ops from ....transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPVisionModelWithProjection +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import IPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, logging, scale_lora_layers, unscale_lora_layers @@ -131,6 +137,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline( DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, + FromSingleFileMixin, IPAdapterMixin, ): r""" @@ -195,7 +202,15 @@ class StableDiffusionXLControlNetImg2ImgPipeline( "feature_extractor", "image_encoder", ] - _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "add_neg_time_ids", + ] def __init__( self, @@ -468,17 +483,15 @@ def encode_image(self, image, num_images_per_prompt, output_hidden_states=None): image = image.to(dtype=dtype) if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True)[2][-2] image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - ops.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] + uncond_image_enc_hidden_states = self.image_encoder(ops.zeros_like(image), output_hidden_states=True)[2][-2] uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( num_images_per_prompt, dim=0 ) return image_enc_hidden_states, uncond_image_enc_hidden_states else: - image_embeds = self.image_encoder(image).image_embeds + image_embeds = self.image_encoder(image)[0] image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = ops.zeros_like(image_embeds) @@ -822,6 +835,12 @@ def prepare_latents( f"`image` has to be of type `mindspore.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = ms.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = ms.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + image = image.to(dtype=dtype) batch_size = batch_size * num_images_per_prompt @@ -854,7 +873,12 @@ def prepare_latents( self.vae.to(dtype) init_latents = init_latents.to(dtype) - init_latents = self.vae.config.scaling_factor * init_latents + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(dtype) + latents_std = latents_std.to(dtype) + init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std + else: + init_latents = self.vae.config.scaling_factor * init_latents if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: # expand init_latents for batch_size @@ -1003,7 +1027,9 @@ def __call__( aesthetic_score: float = 6.0, negative_aesthetic_score: float = 2.5, clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): @@ -1024,11 +1050,11 @@ def __call__( control_image (`ms.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[ms.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: `List[List[ms.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If - the type is specified as `ms.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can - also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If - height and/or width are passed, `image` is resized according to them. If multiple ControlNets are - specified in init, images must be passed as a list such that each element of the list can be correctly - batched for input to a single controlnet. + the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also + be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height + and/or width are passed, `image` is resized according to them. If multiple ControlNets are specified in + init, images must be passed as a list such that each element of the list can be correctly batched for + input to a single controlnet. height (`int`, *optional*, defaults to the size of control_image): The height in pixels of the generated image. Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) @@ -1087,10 +1113,10 @@ def __call__( input argument. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. ip_adapter_image_embeds (`List[ms.Tensor]`, *optional*): - Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. - Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding - if `do_classifier_free_guidance` is set to `True`. - If not provided, embeddings are computed from the `ip_adapter_image` input argument. + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -1152,15 +1178,15 @@ def __call__( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the - `._callback_tensor_inputs` attribute of your pipeine class. + `._callback_tensor_inputs` attribute of your pipeline class. Examples: @@ -1186,6 +1212,9 @@ def __call__( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + controlnet = self.controlnet # align format for control guidance @@ -1321,15 +1350,16 @@ def __call__( self._num_timesteps = len(timesteps) # 6. Prepare latent variables - latents = self.prepare_latents( - image, - latent_timestep, - batch_size, - num_images_per_prompt, - prompt_embeds.dtype, - generator, - True, - ) + if latents is None: + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + generator, + True, + ) # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) @@ -1479,6 +1509,12 @@ def __call__( latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): diff --git a/mindone/diffusers/pipelines/controlnet_sd3/__init__.py b/mindone/diffusers/pipelines/controlnet_sd3/__init__.py new file mode 100644 index 0000000000..38ce1f98bf --- /dev/null +++ b/mindone/diffusers/pipelines/controlnet_sd3/__init__.py @@ -0,0 +1,18 @@ +from typing import TYPE_CHECKING + +from ...utils import _LazyModule + +_import_structure = {} +_import_structure["pipeline_stable_diffusion_3_controlnet"] = ["StableDiffusion3ControlNetPipeline"] + +if TYPE_CHECKING: + from .pipeline_stable_diffusion_3_controlnet import StableDiffusion3ControlNetPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/mindone/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py b/mindone/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py new file mode 100644 index 0000000000..01ac163465 --- /dev/null +++ b/mindone/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py @@ -0,0 +1,989 @@ +# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +from transformers import CLIPTokenizer, T5TokenizerFast + +import mindspore as ms +from mindspore import ops + +from ....transformers import CLIPTextModelWithProjection, T5EncoderModel +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin # , SD3LoraLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel +from ...models.transformers import SD3Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ...utils.mindspore_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from ..stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import mindspore + >>> from mindone.diffusers import StableDiffusion3ControlNetPipeline + >>> from mindone.diffusers.models import SD3ControlNetModel, SD3MultiControlNetModel + >>> from mindone.diffusers.utils import load_image + + >>> controlnet = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Canny", mindspore_dtype=mindspore.float16) + + >>> pipe = StableDiffusion3ControlNetPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet, mindspore_dtype=mindspore.float16 + ... ) + >>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg") + >>> prompt = "A girl holding a sign that says InstantX" + >>> image = pipe(prompt, control_image=control_image, controlnet_conditioning_scale=0.7)[0][0] + >>> image.save("sd3.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + timesteps: Optional[List[int]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, + `timesteps` must be `None`. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` + must be `None`. + + Returns: + `Tuple[ms.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and + the second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusion3ControlNetPipeline(DiffusionPipeline, FromSingleFileMixin): + r""" + Args: + transformer ([`SD3Transformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant, + with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size` + as its dimension. + text_encoder_2 ([`CLIPTextModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + text_encoder_3 ([`T5EncoderModel`]): + Frozen text-encoder. Stable Diffusion 3 uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_3 (`T5TokenizerFast`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + controlnet ([`SD3ControlNetModel`] or `List[SD3ControlNetModel]` or [`SD3MultiControlNetModel`]): + Provides additional conditioning to the `unet` during the denoising process. If you set multiple + ControlNets as a list, the outputs from each ControlNet are added together to create one combined + additional conditioning. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"] + + def __init__( + self, + transformer: SD3Transformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer_2: CLIPTokenizer, + text_encoder_3: T5EncoderModel, + tokenizer_3: T5TokenizerFast, + controlnet: Union[ + SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel + ], + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + text_encoder_3=text_encoder_3, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + tokenizer_3=tokenizer_3, + transformer=transformer, + scheduler=scheduler, + controlnet=controlnet, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") and self.transformer is not None + else 128 + ) + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + dtype=None, + ): + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if self.text_encoder_3 is None: + return ops.zeros( + (batch_size, self.tokenizer_max_length, self.transformer.config.joint_attention_dim), + dtype=dtype, + ) + + text_inputs = self.tokenizer_3( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + add_special_tokens=True, + return_tensors="np", + ) + text_input_ids = ms.Tensor.from_numpy(text_inputs.input_ids) + untruncated_ids = ms.Tensor.from_numpy( + self.tokenizer_3(prompt, padding="longest", return_tensors="np").input_ids + ) + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not ops.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_3(text_input_ids)[0] + + dtype = self.text_encoder_3.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.tile((1, num_images_per_prompt, 1)) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + clip_skip: Optional[int] = None, + clip_model_index: int = 0, + ): + clip_tokenizers = [self.tokenizer, self.tokenizer_2] + clip_text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = clip_tokenizers[clip_model_index] + text_encoder = clip_text_encoders[clip_model_index] + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_tensors="np", + ) + + text_input_ids = ms.Tensor.from_numpy(text_inputs.input_ids) + untruncated_ids = ms.Tensor.from_numpy(tokenizer(prompt, padding="longest", return_tensors="np").input_ids) + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not ops.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = text_encoder(text_input_ids, output_hidden_states=True) + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds[2][-2] + else: + prompt_embeds = prompt_embeds[2][-(clip_skip + 2)] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype) + + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.tile((1, num_images_per_prompt, 1)) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.tile((1, num_images_per_prompt, 1)) + pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds, pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + prompt_3: Union[str, List[str]], + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[ms.Tensor] = None, + negative_prompt_embeds: Optional[ms.Tensor] = None, + pooled_prompt_embeds: Optional[ms.Tensor] = None, + negative_pooled_prompt_embeds: Optional[ms.Tensor] = None, + clip_skip: Optional[int] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + used in all text-encoders + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + prompt_3 = prompt_3 or prompt + prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 + + prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=0, + ) + prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=1, + ) + clip_prompt_embeds = ops.cat([prompt_embed, prompt_2_embed], axis=-1) + + t5_prompt_embed = self._get_t5_prompt_embeds( + prompt=prompt_3, + num_images_per_prompt=num_images_per_prompt, + ) + + clip_prompt_embeds = ops.pad( + clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) + ) + + prompt_embeds = ops.cat([clip_prompt_embeds, t5_prompt_embed], axis=-2) + pooled_prompt_embeds = ops.cat([pooled_prompt_embed, pooled_prompt_2_embed], axis=-1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_3 = negative_prompt_3 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + negative_prompt_3 = ( + batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 + ) + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds( + negative_prompt, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=0, + ) + negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds( + negative_prompt_2, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=1, + ) + negative_clip_prompt_embeds = ops.cat([negative_prompt_embed, negative_prompt_2_embed], axis=-1) + + t5_negative_prompt_embed = self._get_t5_prompt_embeds( + prompt=negative_prompt_3, + num_images_per_prompt=num_images_per_prompt, + ) + + negative_clip_prompt_embeds = ops.pad( + negative_clip_prompt_embeds, + (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]), + ) + + negative_prompt_embeds = ops.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], axis=-2) + negative_pooled_prompt_embeds = ops.cat( + [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], axis=-1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + def check_inputs( + self, + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=None, + negative_prompt_2=None, + negative_prompt_3=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" # noqa: E501 + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_3 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)): + raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_3 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." # noqa: E501 + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." # noqa: E501 + ) + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + generator, + latents=None, + ): + if latents is not None: + return latents.to(dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, dtype=dtype) + + return latents + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, ms.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = ops.cat([image] * 2) + + return image + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + prompt_3: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + control_image: PipelineImageInput = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + controlnet_pooled_projections: Optional[ms.Tensor] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[np.random.Generator, List[np.random.Generator]]] = None, + latents: Optional[ms.Tensor] = None, + prompt_embeds: Optional[ms.Tensor] = None, + negative_prompt_embeds: Optional[ms.Tensor] = None, + pooled_prompt_embeds: Optional[ms.Tensor] = None, + negative_pooled_prompt_embeds: Optional[ms.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = False, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + will be used instead + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + control_image (`ms.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[ms.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[ms.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `ms.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + controlnet_pooled_projections (`ms.Tensor` of shape `(batch_size, projection_dim)`): + Embeddings projected from the embeddings of controlnet input conditions. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used instead + negative_prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used instead + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`np.random.Generator` or `List[np.random.Generator]`, *optional*): + One or a list of [numpy generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`ms.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(self.controlnet.nets) if isinstance(self.controlnet, SD3MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + dtype = self.transformer.dtype + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + clip_skip=self.clip_skip, + num_images_per_prompt=num_images_per_prompt, + ) + + if self.do_classifier_free_guidance: + prompt_embeds = ops.cat([negative_prompt_embeds, prompt_embeds], axis=0) + pooled_prompt_embeds = ops.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], axis=0) + + # 3. Prepare control image + if isinstance(self.controlnet, SD3ControlNetModel): + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + dtype=dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=False, + ) + height, width = control_image.shape[-2:] + + control_image = self.vae.diag_gauss_dist.sample(self.vae.encode(control_image)[0]) + control_image = control_image * self.vae.config.scaling_factor + + elif isinstance(self.controlnet, SD3MultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + dtype=dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=False, + ) + + control_image_ = self.vae.diag_gauss_dist.sample(self.vae.encode(control_image_)[0]) + control_image_ = control_image_ * self.vae.config.scaling_factor + + control_images.append(control_image_) + + control_image = control_images + else: + assert False + + if controlnet_pooled_projections is None: + controlnet_pooled_projections = ops.zeros_like(pooled_prompt_embeds) + else: + controlnet_pooled_projections = controlnet_pooled_projections or pooled_prompt_embeds + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + generator, + latents, + ) + + # 6. Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(self.controlnet, SD3ControlNetModel) else keeps) + + # 7. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = ops.cat([latents] * 2) if self.do_classifier_free_guidance else latents + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.broadcast_to((latent_model_input.shape[0],)) + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + # controlnet(s) inference + control_block_samples = self.controlnet( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + pooled_projections=controlnet_pooled_projections, + joint_attention_kwargs=self.joint_attention_kwargs, + controlnet_cond=control_image, + conditioning_scale=cond_scale, + return_dict=False, + )[0] + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + pooled_projections=pooled_prompt_embeds, + block_controlnet_hidden_states=ms.mutable(control_block_samples), + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if output_type == "latent": + image = latents + + else: + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + if not return_dict: + return (image,) + + return StableDiffusion3PipelineOutput(images=image) diff --git a/mindone/diffusers/pipelines/controlnet_xs/__init__.py b/mindone/diffusers/pipelines/controlnet_xs/__init__.py new file mode 100644 index 0000000000..b7de670b8c --- /dev/null +++ b/mindone/diffusers/pipelines/controlnet_xs/__init__.py @@ -0,0 +1,21 @@ +from typing import TYPE_CHECKING + +from ...utils import _LazyModule + +_import_structure = {} +_import_structure["pipeline_controlnet_xs"] = ["StableDiffusionControlNetXSPipeline"] +_import_structure["pipeline_controlnet_xs_sd_xl"] = ["StableDiffusionXLControlNetXSPipeline"] + + +if TYPE_CHECKING: + from .pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline + from .pipeline_controlnet_xs_sd_xl import StableDiffusionXLControlNetXSPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/mindone/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/mindone/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py new file mode 100644 index 0000000000..26dccedf12 --- /dev/null +++ b/mindone/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py @@ -0,0 +1,870 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +from transformers import CLIPImageProcessor, CLIPTokenizer + +import mindspore as ms +from mindspore import ops + +from ....transformers import CLIPTextModel +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import deprecate, logging, scale_lora_layers, unscale_lora_layers +from ...utils.mindspore_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers + >>> from mindone.diffusers import StableDiffusionControlNetXSPipeline, ControlNetXSAdapter + >>> from mindone.diffusers.utils import load_image + >>> import numpy as np + >>> import mindspore + + >>> import cv2 + >>> from PIL import Image + + >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" + >>> negative_prompt = "low quality, bad quality, sketches" + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png" + ... ) + + >>> # initialize the models and pipeline + >>> controlnet_conditioning_scale = 0.5 + + >>> controlnet = ControlNetXSAdapter.from_pretrained( + ... "UmerHA/Testing-ConrolNetXS-SD2.1-canny", mindspore_dtype=mindspore.float16 + ... ) + >>> pipe = StableDiffusionControlNetXSPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, mindspore_dtype=mindspore.float16 + ... ) + + >>> # get canny image + >>> image = np.array(image) + >>> image = cv2.Canny(image, 100, 200) + >>> image = image[:, :, None] + >>> image = np.concatenate([image, image, image], axis=2) + >>> canny_image = Image.fromarray(image) + >>> # generate image + >>> image = pipe( + ... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image + ... )[0][0] + ``` +""" + + +class StableDiffusionControlNetXSPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin +): + r""" + Pipeline for text-to-image generation using Stable Diffusion with ControlNet-XS guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A [`UNet2DConditionModel`] used to create a UNetControlNetXSModel to denoise the encoded image latents. + controlnet ([`ControlNetXSAdapter`]): + A [`ControlNetXSAdapter`] to be used in combination with `unet` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: Union[UNet2DConditionModel, UNetControlNetXSModel], + controlnet: ControlNetXSAdapter, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if isinstance(unet, UNet2DConditionModel): + unet = UNetControlNetXSModel.from_unet(unet, controlnet) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[ms.Tensor] = None, + negative_prompt_embeds: Optional[ms.Tensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." # noqa: E501 + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = ops.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[ms.Tensor] = None, + negative_prompt_embeds: Optional[ms.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="np").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not ops.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = ms.Tensor(text_inputs.attention_mask) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(ms.Tensor(text_input_ids), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + ms.Tensor(text_input_ids), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.tile((1, num_images_per_prompt, 1)) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="np", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = ms.Tensor(uncond_input.attention_mask) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + ms.Tensor(uncond_input.input_ids), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype) + + negative_prompt_embeds = negative_prompt_embeds.tile((1, num_images_per_prompt, 1)) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, LoraLoaderMixin): + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if ops.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="np") + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=ms.Tensor(safety_checker_input.pixel_values).to(dtype) + ) + + # Warning for safety checker operations here as it couldn't been done in construct() + if ops.any(has_nsfw_concept): + logger.warning( + "Potential NSFW content was detected in one or more images. A black image will be returned instead." + " Try again with a different prompt and/or seed." + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + ): + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" # noqa: E501 + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Check `image` and `controlnet_conditioning_scale` + if isinstance(self.unet, UNetControlNetXSModel): + self.check_image(image, prompt, prompt_embeds) + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + else: + assert False + + start, end = control_guidance_start, control_guidance_end + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, ms.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], ms.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, mindspore tensor, list of PIL images, list of numpy arrays or list of mindspore tensors, but is {type(image)}" # noqa: E501 + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" # noqa: E501 + ) + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + dtype, + do_classifier_free_guidance=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=ms.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(dtype=dtype) + + if do_classifier_free_guidance: + image = ops.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, dtype=dtype) + else: + latents = latents.to(dtype) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + # wtf? The above line changes the dtype of latents from fp16 to fp32, so we need a casting. + latents = latents.to(dtype=dtype) + return latents + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale + def guidance_scale(self): + return self._guidance_scale + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip + def clip_skip(self): + return self._clip_skip + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps + def num_timesteps(self): + return self._num_timesteps + + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[np.random.Generator, List[np.random.Generator]]] = None, + latents: Optional[ms.Tensor] = None, + prompt_embeds: Optional[ms.Tensor] = None, + negative_prompt_embeds: Optional[ms.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = False, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + control_guidance_start: float = 0.0, + control_guidance_end: float = 1.0, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`ms.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[ms.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[ms.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `ms.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`np.random.Generator` or `List[np.random.Generator]`, *optional*): + A [`np.random.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`ms.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeine class. + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + unet = self.unet + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + image, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = ops.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare image + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + dtype=unet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + height, width = image.shape[-2:] + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = ops.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + apply_control = ( + i / len(timesteps) >= control_guidance_start and (i + 1) / len(timesteps) <= control_guidance_end + ) + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=prompt_embeds, + controlnet_cond=image, + conditioning_scale=controlnet_conditioning_scale, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + apply_control=apply_control, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0] + image, has_nsfw_concept = self.run_safety_checker(image, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/mindone/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/mindone/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py new file mode 100644 index 0000000000..881bba4ac4 --- /dev/null +++ b/mindone/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py @@ -0,0 +1,1034 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +from transformers import CLIPImageProcessor, CLIPTokenizer + +import mindspore as ms +from mindspore import ops + +from ....transformers import CLIPTextModel, CLIPTextModelWithProjection +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import logging, scale_lora_layers, unscale_lora_layers +from ...utils.mindspore_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers + >>> from mindone.diffusers import StableDiffusionXLControlNetXSPipeline, ControlNetXSAdapter, AutoencoderKL + >>> from mindone.diffusers.utils import load_image + >>> import numpy as np + >>> import mindspore + + >>> import cv2 + >>> from PIL import Image + + >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" + >>> negative_prompt = "low quality, bad quality, sketches" + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png" + ... ) + + >>> # initialize the models and pipeline + >>> controlnet_conditioning_scale = 0.5 + >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", mindspore_dtype=mindspore.float16) + >>> controlnet = ControlNetXSAdapter.from_pretrained( + ... "UmerHA/Testing-ConrolNetXS-SDXL-canny", mindspore_dtype=mindspore.float16 + ... ) + >>> pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, mindspore_dtype=mindspore.float16 + ... ) + + >>> # get canny image + >>> image = np.array(image) + >>> image = cv2.Canny(image, 100, 200) + >>> image = image[:, :, None] + >>> image = np.concatenate([image, image, image], axis=2) + >>> canny_image = Image.fromarray(image) + + >>> # generate image + >>> image = pipe( + ... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image + ... )[0][0] + ``` +""" + + +class StableDiffusionXLControlNetXSPipeline( + DiffusionPipeline, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet-XS guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]): + Second frozen text-encoder + ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + tokenizer_2 ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A [`UNet2DConditionModel`] used to create a UNetControlNetXSModel to denoise the encoded image latents. + controlnet ([`ControlNetXSAdapter`]): + A [`ControlNetXSAdapter`] to be used in combination with `unet` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings should always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark](https://github.com/ShieldMnt/invisible-watermark/) library to + watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no + watermarker is used. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "negative_add_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: Union[UNet2DConditionModel, UNetControlNetXSModel], + controlnet: ControlNetXSAdapter, + scheduler: KarrasDiffusionSchedulers, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + feature_extractor: CLIPImageProcessor = None, + ): + super().__init__() + + if isinstance(unet, UNet2DConditionModel): + unet = UNetControlNetXSModel.from_unet(unet, controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.watermark = None + + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[ms.Tensor] = None, + negative_prompt_embeds: Optional[ms.Tensor] = None, + pooled_prompt_embeds: Optional[ms.Tensor] = None, + negative_pooled_prompt_embeds: Optional[ms.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="np").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not ops.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(ms.Tensor(text_input_ids), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds[-1][-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = ops.concat(prompt_embeds_list, axis=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = ops.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = ops.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="np", + ) + + negative_prompt_embeds = text_encoder( + ms.Tensor(uncond_input.input_ids), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds[-1][-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = ops.concat(negative_prompt_embeds_list, axis=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.tile((1, num_images_per_prompt, 1)) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype) + + negative_prompt_embeds = negative_prompt_embeds.tile((1, num_images_per_prompt, 1)) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.tile((1, num_images_per_prompt)).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.tile((1, num_images_per_prompt)).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin): + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin): + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + prompt_2, + image, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + ): + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" # noqa: E501 + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." # noqa: E501 + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." # noqa: E501 + ) + + # Check `image` and ``controlnet_conditioning_scale`` + if isinstance(self.unet, UNetControlNetXSModel): + self.check_image(image, prompt, prompt_embeds) + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + else: + assert False + + start, end = control_guidance_start, control_guidance_end + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, ms.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], ms.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, mindspore tensor, list of PIL images, list of numpy arrays or list of mindspore tensors, but is {type(image)}" # noqa: E501 + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" # noqa: E501 + ) + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + dtype, + do_classifier_free_guidance=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=ms.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(dtype=dtype) + + if do_classifier_free_guidance: + image = ops.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, dtype=dtype) + else: + latents = latents.to(dtype) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + # wtf? The above line changes the dtype of latents from fp16 to fp32, so we need a casting. + latents = latents.to(dtype=dtype) + return latents + + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.base_add_embedding.linear_1.in_channels + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." # noqa: E501 + ) + + add_time_ids = ms.Tensor([add_time_ids], dtype=dtype) + return add_time_ids + + def upcast_vae(self): + self.vae.to(dtype=ms.float32) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale + def guidance_scale(self): + return self._guidance_scale + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip + def clip_skip(self): + return self._clip_skip + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps + def num_timesteps(self): + return self._num_timesteps + + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[np.random.Generator, List[np.random.Generator]]] = None, + latents: Optional[ms.Tensor] = None, + prompt_embeds: Optional[ms.Tensor] = None, + negative_prompt_embeds: Optional[ms.Tensor] = None, + pooled_prompt_embeds: Optional[ms.Tensor] = None, + negative_pooled_prompt_embeds: Optional[ms.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = False, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + control_guidance_start: float = 0.0, + control_guidance_end: float = 1.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders. + image (`ms.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[ms.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[ms.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `ms.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 5.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2` + and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`np.random.Generator` or `List[np.random.Generator]`, *optional*): + A [`np.random.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`ms.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + pooled_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, pooled text embeddings are generated from `prompt` input argument. + negative_pooled_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt + weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. + control_guidance_start (`float`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeine class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] is + returned, otherwise a `tuple` is returned containing the output images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + unet = self.unet + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + image, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + prompt_2, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + + # 4. Prepare image + if isinstance(unet, UNetControlNetXSModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + dtype=unet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + height, width = image.shape[-2:] + else: + assert False + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Prepare added time ids & embeddings + if isinstance(image, list): + original_size = original_size or image[0].shape[-2:] + else: + original_size = original_size or image.shape[-2:] + target_size = target_size or (height, width) + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if do_classifier_free_guidance: + prompt_embeds = ops.cat([negative_prompt_embeds, prompt_embeds], axis=0) + add_text_embeds = ops.cat([negative_pooled_prompt_embeds, add_text_embeds], axis=0) + add_time_ids = ops.cat([negative_add_time_ids, add_time_ids], axis=0) + + add_time_ids = add_time_ids.tile((batch_size * num_images_per_prompt, 1)) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = ops.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + # predict the noise residual + apply_control = ( + i / len(timesteps) >= control_guidance_start and (i + 1) / len(timesteps) <= control_guidance_end + ) + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=prompt_embeds, + controlnet_cond=image, + conditioning_scale=controlnet_conditioning_scale, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=ms.mutable(added_cond_kwargs), + return_dict=False, + apply_control=apply_control, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + # Do not cast manually here for max memory savings like original diffusers as it will be done in + # `if not output_type == "latent"` branch, and original casting outside has not corresponding + # 'cast back to fp16 if needed' which might raise error if pipeline is called repeatedly. + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == ms.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.get_parameters())).dtype) + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=ms.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/mindone/diffusers/pipelines/deepfloyd_if/pipeline_if.py b/mindone/diffusers/pipelines/deepfloyd_if/pipeline_if.py index 0e1b3dada9..c82ecb10cc 100644 --- a/mindone/diffusers/pipelines/deepfloyd_if/pipeline_if.py +++ b/mindone/diffusers/pipelines/deepfloyd_if/pipeline_if.py @@ -652,6 +652,9 @@ def __call__( self.scheduler.set_timesteps(num_inference_steps) timesteps = self.scheduler.timesteps + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(0) + # 5. Prepare intermediate images intermediate_images = self.prepare_intermediate_images( batch_size * num_images_per_prompt, diff --git a/mindone/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py b/mindone/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py index a6b48ee16b..c31de94398 100644 --- a/mindone/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +++ b/mindone/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py @@ -593,7 +593,7 @@ def numpy_to_pt(images): for image_ in image: image_ = image_.convert("RGB") - image_ = resize(image_, self.unet.sample_size) + image_ = resize(image_, self.unet.config.sample_size) image_ = np.array(image_) image_ = image_.astype(np.float32) image_ = image_ / 127.5 - 1 @@ -613,12 +613,15 @@ def numpy_to_pt(images): return image + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps def get_timesteps(self, num_inference_steps, strength): # get the original timestep using init_timestep init_timestep = min(int(num_inference_steps * strength), num_inference_steps) t_start = max(num_inference_steps - init_timestep, 0) - timesteps = self.scheduler.timesteps[t_start:] + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) return timesteps, num_inference_steps - t_start diff --git a/mindone/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py b/mindone/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py index b379dfade9..4dc6cf4c3c 100644 --- a/mindone/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +++ b/mindone/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py @@ -641,7 +641,7 @@ def numpy_to_pt(images): for image_ in image: image_ = image_.convert("RGB") - image_ = resize(image_, self.unet.sample_size) + image_ = resize(image_, self.unet.config.sample_size) image_ = np.array(image_) image_ = image_.astype(np.float32) image_ = image_ / 127.5 - 1 @@ -693,13 +693,15 @@ def preprocess_image(self, image, num_images_per_prompt): return image - # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.get_timesteps + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps def get_timesteps(self, num_inference_steps, strength): # get the original timestep using init_timestep init_timestep = min(int(num_inference_steps * strength), num_inference_steps) t_start = max(num_inference_steps - init_timestep, 0) - timesteps = self.scheduler.timesteps[t_start:] + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) return timesteps, num_inference_steps - t_start diff --git a/mindone/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py b/mindone/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py index a781697888..da22c084ea 100644 --- a/mindone/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +++ b/mindone/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py @@ -634,7 +634,7 @@ def numpy_to_pt(images): for image_ in image: image_ = image_.convert("RGB") - image_ = resize(image_, self.unet.sample_size) + image_ = resize(image_, self.unet.config.sample_size) image_ = np.array(image_) image_ = image_.astype(np.float32) image_ = image_ / 127.5 - 1 @@ -681,7 +681,7 @@ def preprocess_mask_image(self, mask_image) -> ms.Tensor: for mask_image_ in mask_image: mask_image_ = mask_image_.convert("L") - mask_image_ = resize(mask_image_, self.unet.sample_size) + mask_image_ = resize(mask_image_, self.unet.config.sample_size) mask_image_ = np.array(mask_image_) mask_image_ = mask_image_[None, None, :] new_mask_image.append(mask_image_) @@ -703,13 +703,15 @@ def preprocess_mask_image(self, mask_image) -> ms.Tensor: return mask_image - # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.get_timesteps + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps def get_timesteps(self, num_inference_steps, strength): # get the original timestep using init_timestep init_timestep = min(int(num_inference_steps * strength), num_inference_steps) t_start = max(num_inference_steps - init_timestep, 0) - timesteps = self.scheduler.timesteps[t_start:] + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) return timesteps, num_inference_steps - t_start diff --git a/mindone/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py b/mindone/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py index b8828faaec..b1642f3eda 100644 --- a/mindone/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +++ b/mindone/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py @@ -677,7 +677,7 @@ def numpy_to_pt(images): for image_ in image: image_ = image_.convert("RGB") - image_ = resize(image_, self.unet.sample_size) + image_ = resize(image_, self.unet.config.sample_size) image_ = np.array(image_) image_ = image_.astype(np.float32) image_ = image_ / 127.5 - 1 @@ -757,7 +757,7 @@ def preprocess_mask_image(self, mask_image) -> ms.Tensor: for mask_image_ in mask_image: mask_image_ = mask_image_.convert("L") - mask_image_ = resize(mask_image_, self.unet.sample_size) + mask_image_ = resize(mask_image_, self.unet.config.sample_size) mask_image_ = np.array(mask_image_) mask_image_ = mask_image_[None, None, :] new_mask_image.append(mask_image_) @@ -779,13 +779,15 @@ def preprocess_mask_image(self, mask_image) -> ms.Tensor: return mask_image - # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.get_timesteps + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps def get_timesteps(self, num_inference_steps, strength): # get the original timestep using init_timestep init_timestep = min(int(num_inference_steps * strength), num_inference_steps) t_start = max(num_inference_steps - init_timestep, 0) - timesteps = self.scheduler.timesteps[t_start:] + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) return timesteps, num_inference_steps - t_start diff --git a/mindone/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py b/mindone/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py index 132439f0b2..641c87522f 100644 --- a/mindone/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +++ b/mindone/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py @@ -749,6 +749,9 @@ def __call__( self.scheduler.set_timesteps(num_inference_steps) timesteps = self.scheduler.timesteps + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(0) + # 5. Prepare intermediate images num_channels = self.unet.config.in_channels // 2 intermediate_images = self.prepare_intermediate_images( diff --git a/mindone/diffusers/pipelines/deepfloyd_if/watermark.py b/mindone/diffusers/pipelines/deepfloyd_if/watermark.py index a949b8d98f..c1b37d1a84 100644 --- a/mindone/diffusers/pipelines/deepfloyd_if/watermark.py +++ b/mindone/diffusers/pipelines/deepfloyd_if/watermark.py @@ -25,7 +25,7 @@ def __init__(self): self.watermark_image_as_pil = None def apply_watermark(self, images: List[PIL.Image.Image], sample_size=None): - # copied from https://github.com/deep-floyd/IF/blob/b77482e36ca2031cb94dbca1001fc1e6400bf4ab/deepfloyd_if/modules/base.py#L287 + # Copied from https://github.com/deep-floyd/IF/blob/b77482e36ca2031cb94dbca1001fc1e6400bf4ab/deepfloyd_if/modules/base.py#L287 h = images[0].height w = images[0].width diff --git a/mindone/diffusers/pipelines/dit/pipeline_dit.py b/mindone/diffusers/pipelines/dit/pipeline_dit.py index 7fa5d0f24c..6abd60705f 100644 --- a/mindone/diffusers/pipelines/dit/pipeline_dit.py +++ b/mindone/diffusers/pipelines/dit/pipeline_dit.py @@ -25,7 +25,7 @@ import mindspore as ms from mindspore import ops -from ...models import AutoencoderKL, Transformer2DModel +from ...models import AutoencoderKL, DiTTransformer2DModel from ...schedulers import KarrasDiffusionSchedulers from ...utils.mindspore_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput @@ -39,8 +39,8 @@ class DiTPipeline(DiffusionPipeline): implemented for all pipelines (downloading, saving, running on a particular device, etc.). Parameters: - transformer ([`Transformer2DModel`]): - A class conditioned `Transformer2DModel` to denoise the encoded image latents. + transformer ([`DiTTransformer2DModel`]): + A class conditioned `DiTTransformer2DModel` to denoise the encoded image latents. vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. scheduler ([`DDIMScheduler`]): @@ -51,7 +51,7 @@ class DiTPipeline(DiffusionPipeline): def __init__( self, - transformer: Transformer2DModel, + transformer: DiTTransformer2DModel, vae: AutoencoderKL, scheduler: KarrasDiffusionSchedulers, id2label: Optional[Dict[int, str]] = None, diff --git a/mindone/diffusers/pipelines/hunyuandit/__init__.py b/mindone/diffusers/pipelines/hunyuandit/__init__.py new file mode 100644 index 0000000000..d7ef449ccf --- /dev/null +++ b/mindone/diffusers/pipelines/hunyuandit/__init__.py @@ -0,0 +1,20 @@ +from typing import TYPE_CHECKING + +from ...utils import _LazyModule + +_import_structure = { + "pipeline_hunyuandit": ["HunyuanDiTPipeline"], +} + +if TYPE_CHECKING: + from .pipeline_hunyuandit import HunyuanDiTPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/mindone/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py b/mindone/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py new file mode 100644 index 0000000000..461ce6f1e1 --- /dev/null +++ b/mindone/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py @@ -0,0 +1,872 @@ +# Copyright 2024 HunyuanDiT Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +from transformers import BertTokenizer, CLIPImageProcessor, MT5Tokenizer + +import mindspore as ms +from mindspore import ops + +from ....transformers import BertModel, T5EncoderModel +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, HunyuanDiT2DModel +from ...models.embeddings import get_2d_rotary_pos_embed +from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from ...schedulers import DDPMScheduler +from ...utils import logging +from ...utils.mindspore_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from ..stable_diffusion import StableDiffusionPipelineOutput + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import mindspore + >>> from mindone.diffusers import HunyuanDiTPipeline + + >>> pipe = HunyuanDiTPipeline.from_pretrained( + ... "Tencent-Hunyuan/HunyuanDiT-Diffusers", mindspore_dtype=mindspore.float16 + ... ) + + >>> # You may also use English prompt as HunyuanDiT supports both English and Chinese + >>> # prompt = "An astronaut riding a horse" + >>> prompt = "一个宇航员在骑马" + >>> image = pipe(prompt)[0][0] + ``` +""" + +STANDARD_RATIO = np.array( + [ + 1.0, # 1:1 + 4.0 / 3.0, # 4:3 + 3.0 / 4.0, # 3:4 + 16.0 / 9.0, # 16:9 + 9.0 / 16.0, # 9:16 + ] +) +STANDARD_SHAPE = [ + [(1024, 1024), (1280, 1280)], # 1:1 + [(1024, 768), (1152, 864), (1280, 960)], # 4:3 + [(768, 1024), (864, 1152), (960, 1280)], # 3:4 + [(1280, 768)], # 16:9 + [(768, 1280)], # 9:16 +] +STANDARD_AREA = [np.array([w * h for w, h in shapes]) for shapes in STANDARD_SHAPE] +SUPPORTED_SHAPE = [ + (1024, 1024), + (1280, 1280), # 1:1 + (1024, 768), + (1152, 864), + (1280, 960), # 4:3 + (768, 1024), + (864, 1152), + (960, 1280), # 3:4 + (1280, 768), # 16:9 + (768, 1280), # 9:16 +] + + +def map_to_standard_shapes(target_width, target_height): + target_ratio = target_width / target_height + closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio)) + closest_area_idx = np.argmin(np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height)) + width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx] + return width, height + + +def get_resize_crop_region_for_grid(src, tgt_size): + th = tw = tgt_size + h, w = src + + r = h / w + + # resize + if r > 1: + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(axis=tuple(range(1, noise_pred_text.ndim)), keepdims=True) + std_cfg = noise_cfg.std(axis=tuple(range(1, noise_cfg.ndim)), keepdims=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class HunyuanDiTPipeline(DiffusionPipeline): + r""" + Pipeline for English/Chinese-to-image generation using HunyuanDiT. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + HunyuanDiT uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by + ourselves) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. We use + `sdxl-vae-fp16-fix`. + text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + HunyuanDiT uses a fine-tuned [bilingual CLIP]. + tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]): + A `BertTokenizer` or `CLIPTokenizer` to tokenize text. + transformer ([`HunyuanDiT2DModel`]): + The HunyuanDiT model designed by Tencent Hunyuan. + text_encoder_2 (`T5EncoderModel`): + The mT5 embedder. Specifically, it is 't5-v1_1-xxl'. + tokenizer_2 (`MT5Tokenizer`): + The tokenizer for the mT5 embedder. + scheduler ([`DDPMScheduler`]): + A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [ + "safety_checker", + "feature_extractor", + "text_encoder_2", + "tokenizer_2", + "text_encoder", + "tokenizer", + ] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "prompt_embeds_2", + "negative_prompt_embeds_2", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: BertModel, + tokenizer: BertTokenizer, + transformer: HunyuanDiT2DModel, + scheduler: DDPMScheduler, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + text_encoder_2=T5EncoderModel, + tokenizer_2=MT5Tokenizer, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + text_encoder_2=text_encoder_2, + ) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") and self.transformer is not None + else 128 + ) + + def encode_prompt( + self, + prompt: str, + dtype: Optional[ms.Type] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + prompt_embeds: Optional[ms.Tensor] = None, + negative_prompt_embeds: Optional[ms.Tensor] = None, + prompt_attention_mask: Optional[ms.Tensor] = None, + negative_prompt_attention_mask: Optional[ms.Tensor] = None, + max_sequence_length: Optional[int] = None, + text_encoder_index: int = 0, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + dtype (`mindspore.Type`): + mindspore dtype + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + prompt_attention_mask (`ms.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds` is passed directly. + negative_prompt_attention_mask (`ms.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. + max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt. + text_encoder_index (`int`, *optional*): + Index of the text encoder to use. `0` for clip and `1` for T5. + """ + if dtype is None: + if self.text_encoder_2 is not None: + dtype = self.text_encoder_2.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + tokenizers = [self.tokenizer, self.tokenizer_2] + text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = tokenizers[text_encoder_index] + text_encoder = text_encoders[text_encoder_index] + + if max_sequence_length is None: + if text_encoder_index == 0: + max_length = 77 + if text_encoder_index == 1: + max_length = 256 + else: + max_length = max_sequence_length + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + return_tensors="np", + ) + text_input_ids = ms.Tensor.from_numpy(text_inputs.input_ids) + untruncated_ids = ms.Tensor.from_numpy(tokenizer(prompt, padding="longest", return_tensors="np").input_ids) + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not ops.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_attention_mask = ms.Tensor.from_numpy(text_inputs.attention_mask) + prompt_embeds = text_encoder( + text_input_ids, + attention_mask=prompt_attention_mask, + ) + prompt_embeds = prompt_embeds[0] + prompt_attention_mask = prompt_attention_mask.tile((num_images_per_prompt, 1)) + + prompt_embeds = prompt_embeds.to(dtype=dtype) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.tile((1, num_images_per_prompt, 1)) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="np", + ) + + negative_prompt_attention_mask = ms.Tensor.from_numpy(uncond_input.attention_mask) + negative_prompt_embeds = text_encoder( + ms.Tensor.from_numpy(uncond_input.input_ids), + attention_mask=negative_prompt_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_attention_mask = negative_prompt_attention_mask.tile((num_images_per_prompt, 1)) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype) + + negative_prompt_embeds = negative_prompt_embeds.tile((1, num_images_per_prompt, 1)) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if ops.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="np") + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=ms.Tensor(safety_checker_input.pixel_values).to(dtype) + ) + + # Warning for safety checker operations here as it couldn't been done in construct() + if ops.any(has_nsfw_concept): + logger.warning( + "Potential NSFW content was detected in one or more images. A black image will be returned instead." + " Try again with a different prompt and/or seed." + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + prompt_embeds_2=None, + negative_prompt_embeds_2=None, + prompt_attention_mask_2=None, + negative_prompt_attention_mask_2=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" # noqa: E501 + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is None and prompt_embeds_2 is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if prompt_embeds_2 is not None and prompt_attention_mask_2 is None: + raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None: + raise ValueError( + "Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`." + ) + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None: + if prompt_embeds_2.shape != negative_prompt_embeds_2.shape: + raise ValueError( + "`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but" + f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`" + f" {negative_prompt_embeds_2.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, dtype=dtype) + else: + latents = latents.to(dtype) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + # wtf? The above line changes the dtype of latents from fp16 to fp32, so we need a casting. + latents = latents.to(dtype=dtype) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[Union[np.random.Generator, List[np.random.Generator]]] = None, + latents: Optional[ms.Tensor] = None, + prompt_embeds: Optional[ms.Tensor] = None, + prompt_embeds_2: Optional[ms.Tensor] = None, + negative_prompt_embeds: Optional[ms.Tensor] = None, + negative_prompt_embeds_2: Optional[ms.Tensor] = None, + prompt_attention_mask: Optional[ms.Tensor] = None, + prompt_attention_mask_2: Optional[ms.Tensor] = None, + negative_prompt_attention_mask: Optional[ms.Tensor] = None, + negative_prompt_attention_mask_2: Optional[ms.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = False, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = (1024, 1024), + target_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + use_resolution_binning: bool = True, + ): + r""" + The call function to the pipeline for generation with HunyuanDiT. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`np.random.Generator` or `List[np.random.Generator]`, *optional*): + A [`np.random.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + prompt_embeds_2 (`ms.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + negative_prompt_embeds_2 (`ms.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + prompt_attention_mask (`ms.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds` is passed directly. + prompt_attention_mask_2 (`ms.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds_2` is passed directly. + negative_prompt_attention_mask (`ms.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. + negative_prompt_attention_mask_2 (`ms.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds_2` is passed directly. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback_on_step_end (`Callable[[int, int, Dict], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A callback function or a list of callback functions to be called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`List[str]`, *optional*): + A list of tensor inputs that should be passed to the callback function. If not defined, all tensor + inputs will be passed. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Rescale the noise_cfg according to `guidance_rescale`. Based on findings of [Common Diffusion Noise + Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`): + The original size of the image. Used to calculate the time ids. + target_size (`Tuple[int, int]`, *optional*): + The target size of the image. Used to calculate the time ids. + crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`): + The top left coordinates of the crop. Used to calculate the time ids. + use_resolution_binning (`bool`, *optional*, defaults to `True`): + Whether to use resolution binning or not. If `True`, the input resolution will be mapped to the closest + standard resolution. Supported resolutions are 1024x1024, 1280x1280, 1024x768, 1152x864, 1280x960, + 768x1024, 864x1152, 960x1280, 1280x768, and 768x1280. It is recommended to set this to `True`. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. default height and width + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + height = int((height // 16) * 16) + width = int((width // 16) * 16) + + if use_resolution_binning and (height, width) not in SUPPORTED_SHAPE: + width, height = map_to_standard_shapes(width, height) + height = int(height) + width = int(width) + logger.warning(f"Reshaped to (height, width)=({height}, {width}), Supported shapes are {SUPPORTED_SHAPE}") + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + callback_on_step_end_tensor_inputs, + ) + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + + ( + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + dtype=self.transformer.dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=77, + text_encoder_index=0, + ) + ( + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + ) = self.encode_prompt( + prompt=prompt, + dtype=self.transformer.dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds_2, + negative_prompt_embeds=negative_prompt_embeds_2, + prompt_attention_mask=prompt_attention_mask_2, + negative_prompt_attention_mask=negative_prompt_attention_mask_2, + max_sequence_length=256, + text_encoder_index=1, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7 create image_rotary_emb, style embedding & time ids + grid_height = height // 8 // self.transformer.config.patch_size + grid_width = width // 8 // self.transformer.config.patch_size + base_size = 512 // 8 // self.transformer.config.patch_size + grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size) + image_rotary_emb = get_2d_rotary_pos_embed( + self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width) + ) + + style = ms.Tensor([0]) + + target_size = target_size or (height, width) + add_time_ids = list(original_size + target_size + crops_coords_top_left) + add_time_ids = ms.Tensor([add_time_ids], dtype=prompt_embeds.dtype) + + if self.do_classifier_free_guidance: + prompt_embeds = ops.cat([negative_prompt_embeds, prompt_embeds]) + prompt_attention_mask = ops.cat([negative_prompt_attention_mask, prompt_attention_mask]) + prompt_embeds_2 = ops.cat([negative_prompt_embeds_2, prompt_embeds_2]) + prompt_attention_mask_2 = ops.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2]) + add_time_ids = ops.cat([add_time_ids] * 2, axis=0) + style = ops.cat([style] * 2, axis=0) + + add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype).tile((batch_size * num_images_per_prompt, 1)) + style = style.tile((batch_size * num_images_per_prompt,)) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = ops.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input + # also consider the case where t is not a tensor (the most common case) + if ops.is_tensor(latent_model_input): + t_expand = t.broadcast_to((latent_model_input.shape[0],)).to(dtype=latent_model_input.dtype) + else: + t_expand = ms.Tensor([t] * latent_model_input.shape[0]).to(dtype=latent_model_input.dtype) + + # predict the noise residual + noise_pred = self.transformer( + latent_model_input, + t_expand, + encoder_hidden_states=prompt_embeds, + text_embedding_mask=prompt_attention_mask, + encoder_hidden_states_t5=prompt_embeds_2, + text_embedding_mask_t5=prompt_attention_mask_2, + image_meta_size=add_time_ids, + style=style, + image_rotary_emb=ms.mutable(image_rotary_emb), + return_dict=False, + )[0] + + noise_pred, _ = noise_pred.chunk(2, axis=1) + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2) + negative_prompt_embeds_2 = callback_outputs.pop( + "negative_prompt_embeds_2", negative_prompt_embeds_2 + ) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/mindone/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py b/mindone/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py index 652df6a63b..6ac83d338b 100644 --- a/mindone/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +++ b/mindone/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py @@ -30,6 +30,7 @@ from ...schedulers import DDIMScheduler from ...utils import BaseOutput, logging from ...utils.mindspore_utils import randn_tensor +from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -41,9 +42,13 @@ >>> from mindone.diffusers import I2VGenXLPipeline >>> from mindone.diffusers.utils import export_to_gif, load_image - >>> pipeline = I2VGenXLPipeline.from_pretrained("ali-vilab/i2vgen-xl", mindspore_dtype=mindspore.float16, use_safetensors=True) + >>> pipeline = I2VGenXLPipeline.from_pretrained( + ... "ali-vilab/i2vgen-xl", mindspore_dtype=mindspore.float16, variant="fp16" + ... ) - >>> image_url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/i2vgen_xl_images/img_0009.png" + >>> image_url = ( + ... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/i2vgen_xl_images/img_0009.png" + ... ) >>> image = load_image(image_url).convert("RGB") >>> image = image.resize((image.width // 2, image.height // 2)) @@ -64,36 +69,15 @@ """ -# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid -def tensor2vid(video: ms.Tensor, processor: "VaeImageProcessor", output_type: str = "np"): - batch_size, channels, num_frames, height, width = video.shape - outputs = [] - for batch_idx in range(batch_size): - batch_vid = video[batch_idx].permute(1, 0, 2, 3) - batch_output = processor.postprocess(batch_vid, output_type) - - outputs.append(batch_output) - - if output_type == "np": - outputs = np.stack(outputs) - - elif output_type == "pt": - outputs = ops.stack(outputs) - - elif not output_type == "pil": - raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']") - - return outputs - - @dataclass class I2VGenXLPipelineOutput(BaseOutput): r""" Output class for image-to-video pipeline. Args: - frames (`ms.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): - List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing denoised + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape `(batch_size, num_frames, channels, height, width)` """ @@ -146,7 +130,7 @@ def __init__( ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) # `do_resize=False` as we do custom resizing. - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_resize=False) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor, do_resize=False) @property def guidance_scale(self): @@ -329,7 +313,7 @@ def _encode_image(self, image, num_videos_per_prompt): dtype = next(self.image_encoder.get_parameters()).dtype if not isinstance(image, ms.Tensor): - image = self.image_processor.pil_to_numpy(image) + image = self.video_processor.pil_to_numpy(image) # Normalize the image with CLIP training stats. image = self.feature_extractor( @@ -534,7 +518,8 @@ def __call__( width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The width in pixels of the generated image. target_fps (`int`, *optional*): - Frames per second. The rate at which the generated images shall be exported to a video after generation. This is also used as a "micro-condition" while generation. # noqa: E501 + Frames per second. The rate at which the generated images shall be exported to a video after + generation. This is also used as a "micro-condition" while generation. num_frames (`int`, *optional*): The number of video frames to generate. num_inference_steps (`int`, *optional*): @@ -551,9 +536,9 @@ def __call__( num_videos_per_prompt (`int`, *optional*): The number of images to generate per prompt. decode_chunk_size (`int`, *optional*): - The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency - between frames, but also the higher the memory consumption. By default, the decoder will decode all frames at once - for maximal quality. Reduce `decode_chunk_size` to reduce memory usage. + The number of frames to decode at a time. The higher the chunk size, the higher the temporal + consistency between frames, but also the higher the memory consumption. By default, the decoder will + decode all frames at once for maximal quality. Reduce `decode_chunk_size` to reduce memory usage. generator (`np.random.Generator` or `List[np.random.Generator]`, *optional*): A [`np.random.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. @@ -640,7 +625,7 @@ def __call__( # 3.2.2 Image latents. resized_image = _center_crop_wide(image, (width, height)) - image = self.image_processor.preprocess(resized_image).to(dtype=image_embeddings.dtype) + image = self.video_processor.preprocess(resized_image).to(dtype=image_embeddings.dtype) image_latents = self.prepare_image_latents( image, num_frames=num_frames, @@ -726,7 +711,7 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents, decode_chunk_size=decode_chunk_size) - video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) if not return_dict: return (video,) diff --git a/mindone/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py b/mindone/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py index a3150ef40b..b5dba5cfce 100644 --- a/mindone/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +++ b/mindone/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py @@ -59,6 +59,7 @@ def retrieve_timesteps( scheduler, num_inference_steps: Optional[int] = None, timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, **kwargs, ): """ @@ -69,17 +70,21 @@ def retrieve_timesteps( scheduler (`SchedulerMixin`): The scheduler to get timesteps from. num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, - `timesteps` must be `None`. + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default - timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` - must be `None`. + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. Returns: - `Tuple[ms.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the - second element is the number of inference steps. + `Tuple[ms.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and + the second element is the number of inference steps. """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: @@ -90,6 +95,16 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, **kwargs) timesteps = scheduler.timesteps @@ -359,9 +374,10 @@ def encode_prompt( negative_prompt_embeds = negative_prompt_embeds.tile((1, num_images_per_prompt, 1)) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - if isinstance(self, LoraLoaderMixin): - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder is not None: + if isinstance(self, LoraLoaderMixin): + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) return prompt_embeds, negative_prompt_embeds @@ -438,20 +454,22 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt return latents # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding - def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=ms.float32): + def get_guidance_scale_embedding( + self, w: ms.Tensor, embedding_dim: int = 512, dtype: ms.Type = ms.float32 + ) -> ms.Tensor: """ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 Args: - timesteps (`ms.Tensor`): - generate embedding vectors at these timesteps + w (`ms.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. embedding_dim (`int`, *optional*, defaults to 512): - dimension of the embeddings to generate - dtype: - data type of the generated embeddings + Dimension of the embeddings to generate. + dtype (`ms.dtype`, *optional*, defaults to `ms.float32`): + Data type of the generated embeddings. Returns: - `ms.Tensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)` + `ms.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. """ assert len(w.shape) == 1 w = w * 1000.0 @@ -744,9 +762,10 @@ def __call__( else self.scheduler.config.original_inference_steps ) latent_timestep = timesteps[:1] - latents = self.prepare_latents( - image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, generator - ) + if latents is None: + latents = self.prepare_latents( + image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, generator + ) bs = batch_size * num_images_per_prompt # 6. Get Guidance Scale Embedding diff --git a/mindone/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py b/mindone/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py index 0bdad10f73..c2b7532878 100644 --- a/mindone/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +++ b/mindone/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py @@ -60,6 +60,7 @@ def retrieve_timesteps( scheduler, num_inference_steps: Optional[int] = None, timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, **kwargs, ): """ @@ -70,17 +71,21 @@ def retrieve_timesteps( scheduler (`SchedulerMixin`): The scheduler to get timesteps from. num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, - `timesteps` must be `None`. + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default - timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` - must be `None`. + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. Returns: - `Tuple[ms.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the - second element is the number of inference steps. + `Tuple[ms.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and + the second element is the number of inference steps. """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: @@ -91,6 +96,16 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, **kwargs) timesteps = scheduler.timesteps @@ -345,9 +360,10 @@ def encode_prompt( negative_prompt_embeds = negative_prompt_embeds.tile((1, num_images_per_prompt, 1)) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - if isinstance(self, LoraLoaderMixin): - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder is not None: + if isinstance(self, LoraLoaderMixin): + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) return prompt_embeds, negative_prompt_embeds @@ -368,7 +384,12 @@ def run_safety_checker(self, image, dtype): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -384,20 +405,22 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype latents = latents * self.scheduler.init_noise_sigma return latents - def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=ms.float32): + def get_guidance_scale_embedding( + self, w: ms.Tensor, embedding_dim: int = 512, dtype: ms.Type = ms.float32 + ) -> ms.Tensor: """ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 Args: - timesteps (`ms.Tensor`): - generate embedding vectors at these timesteps + w (`ms.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. embedding_dim (`int`, *optional*, defaults to 512): - dimension of the embeddings to generate - dtype: - data type of the generated embeddings + Dimension of the embeddings to generate. + dtype (`ms.dtype`, *optional*, defaults to `ms.float32`): + Data type of the generated embeddings. Returns: - `ms.Tensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)` + `ms.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. """ assert len(w.shape) == 1 w = w * 1000.0 diff --git a/mindone/diffusers/pipelines/marigold/__init__.py b/mindone/diffusers/pipelines/marigold/__init__.py new file mode 100644 index 0000000000..70d85ea64c --- /dev/null +++ b/mindone/diffusers/pipelines/marigold/__init__.py @@ -0,0 +1,23 @@ +from typing import TYPE_CHECKING + +from ...utils import _LazyModule + +_import_structure = {} +_import_structure["marigold_image_processing"] = ["MarigoldImageProcessor"] +_import_structure["pipeline_marigold_depth"] = ["MarigoldDepthOutput", "MarigoldDepthPipeline"] +_import_structure["pipeline_marigold_normals"] = ["MarigoldNormalsOutput", "MarigoldNormalsPipeline"] + +if TYPE_CHECKING: + from .marigold_image_processing import MarigoldImageProcessor + from .pipeline_marigold_depth import MarigoldDepthOutput, MarigoldDepthPipeline + from .pipeline_marigold_normals import MarigoldNormalsOutput, MarigoldNormalsPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/mindone/diffusers/pipelines/marigold/marigold_image_processing.py b/mindone/diffusers/pipelines/marigold/marigold_image_processing.py new file mode 100644 index 0000000000..c2f3a917f7 --- /dev/null +++ b/mindone/diffusers/pipelines/marigold/marigold_image_processing.py @@ -0,0 +1,574 @@ +from typing import List, Optional, Tuple, Union + +import numpy as np +import PIL +from PIL import Image + +import mindspore as ms +from mindspore import ops + +from ... import ConfigMixin +from ...configuration_utils import register_to_config +from ...image_processor import PipelineImageInput +from ...utils import CONFIG_NAME, logging +from ...utils.import_utils import is_matplotlib_available + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class MarigoldImageProcessor(ConfigMixin): + config_name = CONFIG_NAME + + @register_to_config + def __init__( + self, + vae_scale_factor: int = 8, + do_normalize: bool = True, + do_range_check: bool = True, + ): + super().__init__() + + @staticmethod + def expand_tensor_or_array(images: Union[ms.Tensor, np.ndarray]) -> Union[ms.Tensor, np.ndarray]: + """ + Expand a tensor or array to a specified number of images. + """ + if isinstance(images, np.ndarray): + if images.ndim == 2: # [H,W] -> [1,H,W,1] + images = images[None, ..., None] + if images.ndim == 3: # [H,W,C] -> [1,H,W,C] + images = images[None] + elif isinstance(images, ms.Tensor): + if images.ndim == 2: # [H,W] -> [1,1,H,W] + images = images[None, None] + elif images.ndim == 3: # [1,H,W] -> [1,1,H,W] + images = images[None] + else: + raise ValueError(f"Unexpected input type: {type(images)}") + return images + + @staticmethod + def pt_to_numpy(images: ms.Tensor) -> np.ndarray: + """ + Convert a PyTorch tensor to a NumPy image. + """ + images = images.permute(0, 2, 3, 1).float().numpy() + return images + + @staticmethod + def numpy_to_pt(images: np.ndarray) -> ms.Tensor: + """ + Convert a NumPy image to a PyTorch tensor. + """ + if np.issubdtype(images.dtype, np.integer) and not np.issubdtype(images.dtype, np.unsignedinteger): + raise ValueError(f"Input image dtype={images.dtype} cannot be a signed integer.") + if np.issubdtype(images.dtype, np.complexfloating): + raise ValueError(f"Input image dtype={images.dtype} cannot be complex.") + if np.issubdtype(images.dtype, bool): + raise ValueError(f"Input image dtype={images.dtype} cannot be boolean.") + + images = ms.Tensor.from_numpy(images.transpose(0, 3, 1, 2)) + return images + + @staticmethod + def resize_antialias(image: ms.Tensor, size: Tuple[int, int], mode: str, is_aa: Optional[bool] = None) -> ms.Tensor: + if not ops.is_tensor(image): + raise ValueError(f"Invalid input type={type(image)}.") + if not ops.is_floating_point(image): + raise ValueError(f"Invalid input dtype={image.dtype}.") + if image.ndim != 4: + raise ValueError(f"Invalid input dimensions; shape={image.shape}.") + + antialias = is_aa and mode in ("bilinear", "bicubic") # noqa + # abandone argument `antialias=antialias` as MindSpore doesn't support + image = ops.interpolate(image, size, mode=mode) + + return image + + @staticmethod + def resize_to_max_edge(image: ms.Tensor, max_edge_sz: int, mode: str) -> ms.Tensor: + if not ops.is_tensor(image): + raise ValueError(f"Invalid input type={type(image)}.") + if not ops.is_floating_point(image): + raise ValueError(f"Invalid input dtype={image.dtype}.") + if image.ndim != 4: + raise ValueError(f"Invalid input dimensions; shape={image.shape}.") + + h, w = image.shape[-2:] + max_orig = max(h, w) + new_h = h * max_edge_sz // max_orig + new_w = w * max_edge_sz // max_orig + + if new_h == 0 or new_w == 0: + raise ValueError(f"Extreme aspect ratio of the input image: [{w} x {h}]") + + image = MarigoldImageProcessor.resize_antialias(image, (new_h, new_w), mode, is_aa=True) + + return image + + @staticmethod + def pad_image(image: ms.Tensor, align: int) -> Tuple[ms.Tensor, Tuple[int, int]]: + if not ops.is_tensor(image): + raise ValueError(f"Invalid input type={type(image)}.") + if not ops.is_floating_point(image): + raise ValueError(f"Invalid input dtype={image.dtype}.") + if image.ndim != 4: + raise ValueError(f"Invalid input dimensions; shape={image.shape}.") + + h, w = image.shape[-2:] + ph, pw = -h % align, -w % align + + # FIXME: replace with layers_compat.pad (PR#608) + image = ops.pad(image, (0, pw, 0, ph), mode="replicate") + + return image, (ph, pw) + + @staticmethod + def unpad_image(image: ms.Tensor, padding: Tuple[int, int]) -> ms.Tensor: + if not ops.is_tensor(image): + raise ValueError(f"Invalid input type={type(image)}.") + if not ops.is_floating_point(image): + raise ValueError(f"Invalid input dtype={image.dtype}.") + if image.ndim != 4: + raise ValueError(f"Invalid input dimensions; shape={image.shape}.") + + ph, pw = padding + uh = None if ph == 0 else -ph + uw = None if pw == 0 else -pw + + image = image[:, :, :uh, :uw] + + return image + + @staticmethod + def load_image_canonical( + image: Union[ms.Tensor, np.ndarray, Image.Image], + dtype: ms.Type = ms.float32, + ) -> Tuple[ms.Tensor, int]: + if isinstance(image, Image.Image): + image = np.array(image) + + image_dtype_max = None + if isinstance(image, (np.ndarray, ms.Tensor)): + image = MarigoldImageProcessor.expand_tensor_or_array(image) + if image.ndim != 4: + raise ValueError("Input image is not 2-, 3-, or 4-dimensional.") + if isinstance(image, np.ndarray): + if np.issubdtype(image.dtype, np.integer) and not np.issubdtype(image.dtype, np.unsignedinteger): + raise ValueError(f"Input image dtype={image.dtype} cannot be a signed integer.") + if np.issubdtype(image.dtype, np.complexfloating): + raise ValueError(f"Input image dtype={image.dtype} cannot be complex.") + if np.issubdtype(image.dtype, bool): + raise ValueError(f"Input image dtype={image.dtype} cannot be boolean.") + if np.issubdtype(image.dtype, np.unsignedinteger): + image_dtype_max = np.iinfo(image.dtype).max + image = image.astype(np.float32) # because torch does not have unsigned dtypes beyond ms.uint8 + image = MarigoldImageProcessor.numpy_to_pt(image) + + if ops.is_tensor(image) and not ops.is_floating_point(image) and image_dtype_max is None: + if image.dtype != ms.uint8: + raise ValueError(f"Image dtype={image.dtype} is not supported.") + image_dtype_max = 255 + + if not ops.is_tensor(image): + raise ValueError(f"Input type unsupported: {type(image)}.") + + if image.shape[1] == 1: + image = image.tile((1, 3, 1, 1)) # [N,1,H,W] -> [N,3,H,W] + if image.shape[1] != 3: + raise ValueError(f"Input image is not 1- or 3-channel: {image.shape}.") + + image = image.to(dtype=dtype) + + if image_dtype_max is not None: + image = image / image_dtype_max + + return image + + @staticmethod + def check_image_values_range(image: ms.Tensor) -> None: + if not ops.is_tensor(image): + raise ValueError(f"Invalid input type={type(image)}.") + if not ops.is_floating_point(image): + raise ValueError(f"Invalid input dtype={image.dtype}.") + if image.min().item() < 0.0 or image.max().item() > 1.0: + raise ValueError("Input image data is partially outside of the [0,1] range.") + + def preprocess( + self, + image: PipelineImageInput, + processing_resolution: Optional[int] = None, + resample_method_input: str = "bilinear", + dtype: ms.Type = ms.float32, + ): + if isinstance(image, list): + images = None + for i, img in enumerate(image): + img = self.load_image_canonical(img, dtype) # [N,3,H,W] + if images is None: + images = img + else: + if images.shape[2:] != img.shape[2:]: + raise ValueError( + f"Input image[{i}] has incompatible dimensions {img.shape[2:]} with the previous images " + f"{images.shape[2:]}" + ) + images = ops.cat((images, img), axis=0) + image = images + del images + else: + image = self.load_image_canonical(image, dtype) # [N,3,H,W] + + original_resolution = image.shape[2:] + + if self.config.do_range_check: + self.check_image_values_range(image) + + if self.config.do_normalize: + image = image * 2.0 - 1.0 + + if processing_resolution is not None and processing_resolution > 0: + image = self.resize_to_max_edge(image, processing_resolution, resample_method_input) # [N,3,PH,PW] + + image, padding = self.pad_image(image, self.config.vae_scale_factor) # [N,3,PPH,PPW] + + return image, padding, original_resolution + + @staticmethod + def colormap( + image: Union[np.ndarray, ms.Tensor], + cmap: str = "Spectral", + bytes: bool = False, + _force_method: Optional[str] = None, + ) -> Union[np.ndarray, ms.Tensor]: + """ + Converts a monochrome image into an RGB image by applying the specified colormap. This function mimics the + behavior of matplotlib.colormaps, but allows the user to use the most discriminative color maps ("Spectral", + "binary") without having to install or import matplotlib. For all other cases, the function will attempt to use + the native implementation. + + Args: + image: 2D tensor of values between 0 and 1, either as np.ndarray or ms.Tensor. + cmap: Colormap name. + bytes: Whether to return the output as uint8 or floating point image. + _force_method: + Can be used to specify whether to use the native implementation (`"matplotlib"`), the efficient custom + implementation of the select color maps (`"custom"`), or rely on autodetection (`None`, default). + + Returns: + An RGB-colorized tensor corresponding to the input image. + """ + if not (ops.is_tensor(image) or isinstance(image, np.ndarray)): + raise ValueError("Argument must be a numpy array or torch tensor.") + if _force_method not in (None, "matplotlib", "custom"): + raise ValueError("_force_method must be either `None`, `'matplotlib'` or `'custom'`.") + + supported_cmaps = { + "binary": [ + (1.0, 1.0, 1.0), + (0.0, 0.0, 0.0), + ], + "Spectral": [ # Taken from matplotlib/_cm.py + (0.61960784313725492, 0.003921568627450980, 0.25882352941176473), # 0.0 -> [0] + (0.83529411764705885, 0.24313725490196078, 0.30980392156862746), + (0.95686274509803926, 0.42745098039215684, 0.2627450980392157), + (0.99215686274509807, 0.68235294117647061, 0.38039215686274508), + (0.99607843137254903, 0.8784313725490196, 0.54509803921568623), + (1.0, 1.0, 0.74901960784313726), + (0.90196078431372551, 0.96078431372549022, 0.59607843137254901), + (0.6705882352941176, 0.8666666666666667, 0.64313725490196083), + (0.4, 0.76078431372549016, 0.6470588235294118), + (0.19607843137254902, 0.53333333333333333, 0.74117647058823533), + (0.36862745098039218, 0.30980392156862746, 0.63529411764705879), # 1.0 -> [K-1] + ], + } + + def method_matplotlib(image, cmap, bytes=False): + if is_matplotlib_available(): + import matplotlib + else: + return None + + arg_is_pt = ops.is_tensor(image) + if arg_is_pt: + image = image.numpy() + + if cmap not in matplotlib.colormaps: + raise ValueError( + f"Unexpected color map {cmap}; available options are: {', '.join(list(matplotlib.colormaps.keys()))}" + ) + + cmap = matplotlib.colormaps[cmap] + out = cmap(image, bytes=bytes) # [?,4] + out = out[..., :3] # [?,3] + + if arg_is_pt: + out = ms.Tensor(out) + + return out + + def method_custom(image, cmap, bytes=False): + arg_is_np = isinstance(image, np.ndarray) + if arg_is_np: + image = ms.Tensor(image) + if image.dtype == ms.uint8: + image = image.float() / 255 + else: + image = image.float() + + is_cmap_reversed = cmap.endswith("_r") + if is_cmap_reversed: + cmap = cmap[:-2] + + if cmap not in supported_cmaps: + raise ValueError( + f"Only {list(supported_cmaps.keys())} color maps are available without installing matplotlib." + ) + + cmap = supported_cmaps[cmap] + if is_cmap_reversed: + cmap = cmap[::-1] + cmap = ms.Tensor(cmap, dtype=ms.float32) # [K,3] + K = cmap.shape[0] + + pos = image.clamp(min=0, max=1) * (K - 1) + left = pos.long() + right = (left + 1).clamp(max=K - 1) + + d = (pos - left.float()).unsqueeze(-1) + left_colors = cmap[left] + right_colors = cmap[right] + + out = (1 - d) * left_colors + d * right_colors + + if bytes: + out = (out * 255).to(ms.uint8) + + if arg_is_np: + out = out.numpy() + + return out + + if _force_method is None and ops.is_tensor(image) and cmap == "Spectral": + return method_custom(image, cmap, bytes) + + out = None + if _force_method != "custom": + out = method_matplotlib(image, cmap, bytes) + + if _force_method == "matplotlib" and out is None: + raise ImportError("Make sure to install matplotlib if you want to use a color map other than 'Spectral'.") + + if out is None: + out = method_custom(image, cmap, bytes) + + return out + + @staticmethod + def visualize_depth( + depth: Union[ + PIL.Image.Image, + np.ndarray, + ms.Tensor, + List[PIL.Image.Image], + List[np.ndarray], + List[ms.Tensor], + ], + val_min: float = 0.0, + val_max: float = 1.0, + color_map: str = "Spectral", + ) -> Union[PIL.Image.Image, List[PIL.Image.Image]]: + """ + Visualizes depth maps, such as predictions of the `MarigoldDepthPipeline`. + + Args: + depth (`Union[PIL.Image.Image, np.ndarray, ms.Tensor, List[PIL.Image.Image], List[np.ndarray], + List[ms.Tensor]]`): Depth maps. + val_min (`float`, *optional*, defaults to `0.0`): Minimum value of the visualized depth range. + val_max (`float`, *optional*, defaults to `1.0`): Maximum value of the visualized depth range. + color_map (`str`, *optional*, defaults to `"Spectral"`): Color map used to convert a single-channel + depth prediction into colored representation. + + Returns: `PIL.Image.Image` or `List[PIL.Image.Image]` with depth maps visualization. + """ + if val_max <= val_min: + raise ValueError(f"Invalid values range: [{val_min}, {val_max}].") + + def visualize_depth_one(img, idx=None): + prefix = "Depth" + (f"[{idx}]" if idx else "") + if isinstance(img, PIL.Image.Image): + if img.mode != "I;16": + raise ValueError(f"{prefix}: invalid PIL mode={img.mode}.") + img = np.array(img).astype(np.float32) / (2**16 - 1) + if isinstance(img, np.ndarray) or ops.is_tensor(img): + if img.ndim != 2: + raise ValueError(f"{prefix}: unexpected shape={img.shape}.") + if isinstance(img, np.ndarray): + img = ms.Tensor.from_numpy(img) + if not ops.is_floating_point(img): + raise ValueError(f"{prefix}: unexected dtype={img.dtype}.") + else: + raise ValueError(f"{prefix}: unexpected type={type(img)}.") + if val_min != 0.0 or val_max != 1.0: + img = (img - val_min) / (val_max - val_min) + img = MarigoldImageProcessor.colormap(img, cmap=color_map, bytes=True) # [H,W,3] + img = PIL.Image.fromarray(img.numpy()) + return img + + if depth is None or isinstance(depth, list) and any(o is None for o in depth): + raise ValueError("Input depth is `None`") + if isinstance(depth, (np.ndarray, ms.Tensor)): + depth = MarigoldImageProcessor.expand_tensor_or_array(depth) + if isinstance(depth, np.ndarray): + depth = MarigoldImageProcessor.numpy_to_pt(depth) # [N,H,W,1] -> [N,1,H,W] + if not (depth.ndim == 4 and depth.shape[1] == 1): # [N,1,H,W] + raise ValueError(f"Unexpected input shape={depth.shape}, expecting [N,1,H,W].") + return [visualize_depth_one(img[0], idx) for idx, img in enumerate(depth)] + elif isinstance(depth, list): + return [visualize_depth_one(img, idx) for idx, img in enumerate(depth)] + else: + raise ValueError(f"Unexpected input type: {type(depth)}") + + @staticmethod + def export_depth_to_16bit_png( + depth: Union[np.ndarray, ms.Tensor, List[np.ndarray], List[ms.Tensor]], + val_min: float = 0.0, + val_max: float = 1.0, + ) -> Union[PIL.Image.Image, List[PIL.Image.Image]]: + def export_depth_to_16bit_png_one(img, idx=None): + prefix = "Depth" + (f"[{idx}]" if idx else "") + if not isinstance(img, np.ndarray) and not ops.is_tensor(img): + raise ValueError(f"{prefix}: unexpected type={type(img)}.") + if img.ndim != 2: + raise ValueError(f"{prefix}: unexpected shape={img.shape}.") + if ops.is_tensor(img): + img = img.numpy() + if not np.issubdtype(img.dtype, np.floating): + raise ValueError(f"{prefix}: unexected dtype={img.dtype}.") + if val_min != 0.0 or val_max != 1.0: + img = (img - val_min) / (val_max - val_min) + img = (img * (2**16 - 1)).astype(np.uint16) + img = PIL.Image.fromarray(img, mode="I;16") + return img + + if depth is None or isinstance(depth, list) and any(o is None for o in depth): + raise ValueError("Input depth is `None`") + if isinstance(depth, (np.ndarray, ms.Tensor)): + depth = MarigoldImageProcessor.expand_tensor_or_array(depth) + if isinstance(depth, np.ndarray): + depth = MarigoldImageProcessor.numpy_to_pt(depth) # [N,H,W,1] -> [N,1,H,W] + if not (depth.ndim == 4 and depth.shape[1] == 1): + raise ValueError(f"Unexpected input shape={depth.shape}, expecting [N,1,H,W].") + return [export_depth_to_16bit_png_one(img[0], idx) for idx, img in enumerate(depth)] + elif isinstance(depth, list): + return [export_depth_to_16bit_png_one(img, idx) for idx, img in enumerate(depth)] + else: + raise ValueError(f"Unexpected input type: {type(depth)}") + + @staticmethod + def visualize_normals( + normals: Union[ + np.ndarray, + ms.Tensor, + List[np.ndarray], + List[ms.Tensor], + ], + flip_x: bool = False, + flip_y: bool = False, + flip_z: bool = False, + ) -> Union[PIL.Image.Image, List[PIL.Image.Image]]: + """ + Visualizes surface normals, such as predictions of the `MarigoldNormalsPipeline`. + + Args: + normals (`Union[np.ndarray, ms.Tensor, List[np.ndarray], List[ms.Tensor]]`): + Surface normals. + flip_x (`bool`, *optional*, defaults to `False`): Flips the X axis of the normals frame of reference. + Default direction is right. + flip_y (`bool`, *optional*, defaults to `False`): Flips the Y axis of the normals frame of reference. + Default direction is top. + flip_z (`bool`, *optional*, defaults to `False`): Flips the Z axis of the normals frame of reference. + Default direction is facing the observer. + + Returns: `PIL.Image.Image` or `List[PIL.Image.Image]` with surface normals visualization. + """ + flip_vec = None + if any((flip_x, flip_y, flip_z)): + flip_vec = ms.Tensor( + [ + (-1) ** flip_x, + (-1) ** flip_y, + (-1) ** flip_z, + ], + dtype=ms.float32, + ) + + def visualize_normals_one(img, idx=None): + img = img.permute(1, 2, 0) + if flip_vec is not None: + img *= flip_vec + img = (img + 1.0) * 0.5 + img = (img * 255).to(dtype=ms.uint8).numpy() + img = PIL.Image.fromarray(img) + return img + + if normals is None or isinstance(normals, list) and any(o is None for o in normals): + raise ValueError("Input normals is `None`") + if isinstance(normals, (np.ndarray, ms.Tensor)): + normals = MarigoldImageProcessor.expand_tensor_or_array(normals) + if isinstance(normals, np.ndarray): + normals = MarigoldImageProcessor.numpy_to_pt(normals) # [N,3,H,W] + if not (normals.ndim == 4 and normals.shape[1] == 3): + raise ValueError(f"Unexpected input shape={normals.shape}, expecting [N,3,H,W].") + return [visualize_normals_one(img, idx) for idx, img in enumerate(normals)] + elif isinstance(normals, list): + return [visualize_normals_one(img, idx) for idx, img in enumerate(normals)] + else: + raise ValueError(f"Unexpected input type: {type(normals)}") + + @staticmethod + def visualize_uncertainty( + uncertainty: Union[ + np.ndarray, + ms.Tensor, + List[np.ndarray], + List[ms.Tensor], + ], + saturation_percentile=95, + ) -> Union[PIL.Image.Image, List[PIL.Image.Image]]: + """ + Visualizes dense uncertainties, such as produced by `MarigoldDepthPipeline` or `MarigoldNormalsPipeline`. + + Args: + uncertainty (`Union[np.ndarray, ms.Tensor, List[np.ndarray], List[ms.Tensor]]`): + Uncertainty maps. + saturation_percentile (`int`, *optional*, defaults to `95`): + Specifies the percentile uncertainty value visualized with maximum intensity. + + Returns: `PIL.Image.Image` or `List[PIL.Image.Image]` with uncertainty visualization. + """ + + def visualize_uncertainty_one(img, idx=None): + prefix = "Uncertainty" + (f"[{idx}]" if idx else "") + if img.min() < 0: + raise ValueError(f"{prefix}: unexected data range, min={img.min()}.") + img = img.squeeze(0).numpy() + saturation_value = np.percentile(img, saturation_percentile) + img = np.clip(img * 255 / saturation_value, 0, 255) + img = img.astype(np.uint8) + img = PIL.Image.fromarray(img) + return img + + if uncertainty is None or isinstance(uncertainty, list) and any(o is None for o in uncertainty): + raise ValueError("Input uncertainty is `None`") + if isinstance(uncertainty, (np.ndarray, ms.Tensor)): + uncertainty = MarigoldImageProcessor.expand_tensor_or_array(uncertainty) + if isinstance(uncertainty, np.ndarray): + uncertainty = MarigoldImageProcessor.numpy_to_pt(uncertainty) # [N,1,H,W] + if not (uncertainty.ndim == 4 and uncertainty.shape[1] == 1): + raise ValueError(f"Unexpected input shape={uncertainty.shape}, expecting [N,1,H,W].") + return [visualize_uncertainty_one(img, idx) for idx, img in enumerate(uncertainty)] + elif isinstance(uncertainty, list): + return [visualize_uncertainty_one(img, idx) for idx, img in enumerate(uncertainty)] + else: + raise ValueError(f"Unexpected input type: {type(uncertainty)}") diff --git a/mindone/diffusers/pipelines/marigold/pipeline_marigold_depth.py b/mindone/diffusers/pipelines/marigold/pipeline_marigold_depth.py new file mode 100644 index 0000000000..1a1e33a97b --- /dev/null +++ b/mindone/diffusers/pipelines/marigold/pipeline_marigold_depth.py @@ -0,0 +1,798 @@ +# Copyright 2024 Marigold authors, PRS ETH Zurich. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# More information and citation instructions are available on the +# Marigold project website: https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- +from dataclasses import dataclass +from functools import partial +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +from PIL import Image +from tqdm.auto import tqdm +from transformers import CLIPTokenizer + +import mindspore as ms +from mindspore import ops + +from ....transformers import CLIPTextModel +from ...image_processor import PipelineImageInput +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import DDIMScheduler, LCMScheduler +from ...utils import BaseOutput, logging +from ...utils.import_utils import is_scipy_available +from ...utils.mindspore_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .marigold_image_processing import MarigoldImageProcessor + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ +Examples: +```py +>>> from mindone import diffusers +>>> import mindspore + +>>> pipe = diffusers.MarigoldDepthPipeline.from_pretrained( +... "prs-eth/marigold-depth-lcm-v1-0", variant="fp16", mindspore_dtype=mindspore.float16 +... ) + +>>> image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg") +>>> depth = pipe(image) + +>>> vis = pipe.image_processor.visualize_depth(depth[0]) +>>> vis[0].save("einstein_depth.png") + +>>> depth_16bit = pipe.image_processor.export_depth_to_16bit_png(depth[0]) +>>> depth_16bit[0].save("einstein_depth_16bit.png") +``` +""" + + +@dataclass +class MarigoldDepthOutput(BaseOutput): + """ + Output class for Marigold monocular depth prediction pipeline. + + Args: + prediction (`np.ndarray`, `ms.Tensor`): + Predicted depth maps with values in the range [0, 1]. The shape is always $numimages \times 1 \times height + \times width$, regardless of whether the images were passed as a 4D array or a list. + uncertainty (`None`, `np.ndarray`, `ms.Tensor`): + Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is $numimages + \times 1 \times height \times width$. + latent (`None`, `ms.Tensor`): + Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline. + The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$. + """ + + prediction: Union[np.ndarray, ms.Tensor] + uncertainty: Union[None, np.ndarray, ms.Tensor] + latent: Union[None, ms.Tensor] + + +class MarigoldDepthPipeline(DiffusionPipeline): + """ + Pipeline for monocular depth estimation using the Marigold method: https://marigoldmonodepth.github.io. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + unet (`UNet2DConditionModel`): + Conditional U-Net to denoise the depth latent, conditioned on image latent. + vae (`AutoencoderKL`): + Variational Auto-Encoder (VAE) Model to encode and decode images and predictions to and from latent + representations. + scheduler (`DDIMScheduler` or `LCMScheduler`): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + text_encoder (`CLIPTextModel`): + Text-encoder, for empty text embedding. + tokenizer (`CLIPTokenizer`): + CLIP tokenizer. + prediction_type (`str`, *optional*): + Type of predictions made by the model. + scale_invariant (`bool`, *optional*): + A model property specifying whether the predicted depth maps are scale-invariant. This value must be set in + the model config. When used together with the `shift_invariant=True` flag, the model is also called + "affine-invariant". NB: overriding this value is not supported. + shift_invariant (`bool`, *optional*): + A model property specifying whether the predicted depth maps are shift-invariant. This value must be set in + the model config. When used together with the `scale_invariant=True` flag, the model is also called + "affine-invariant". NB: overriding this value is not supported. + default_denoising_steps (`int`, *optional*): + The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable + quality with the given model. This value must be set in the model config. When the pipeline is called + without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure + reasonable results with various model flavors compatible with the pipeline, such as those relying on very + short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`). + default_processing_resolution (`int`, *optional*): + The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in + the model config. When the pipeline is called without explicitly setting `processing_resolution`, the + default value is used. This is required to ensure reasonable results with various model flavors trained + with varying optimal processing resolution values. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + supported_prediction_types = ("depth", "disparity") + + def __init__( + self, + unet: UNet2DConditionModel, + vae: AutoencoderKL, + scheduler: Union[DDIMScheduler, LCMScheduler], + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + prediction_type: Optional[str] = None, + scale_invariant: Optional[bool] = True, + shift_invariant: Optional[bool] = True, + default_denoising_steps: Optional[int] = None, + default_processing_resolution: Optional[int] = None, + ): + super().__init__() + + if prediction_type not in self.supported_prediction_types: + logger.warning( + f"Potentially unsupported `prediction_type='{prediction_type}'`; values supported by the pipeline: " + f"{self.supported_prediction_types}." + ) + + self.register_modules( + unet=unet, + vae=vae, + scheduler=scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + ) + self.register_to_config( + prediction_type=prediction_type, + scale_invariant=scale_invariant, + shift_invariant=shift_invariant, + default_denoising_steps=default_denoising_steps, + default_processing_resolution=default_processing_resolution, + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + self.scale_invariant = scale_invariant + self.shift_invariant = shift_invariant + self.default_denoising_steps = default_denoising_steps + self.default_processing_resolution = default_processing_resolution + + self.empty_text_embedding = None + + self.image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def check_inputs( + self, + image: PipelineImageInput, + num_inference_steps: int, + ensemble_size: int, + processing_resolution: int, + resample_method_input: str, + resample_method_output: str, + batch_size: int, + ensembling_kwargs: Optional[Dict[str, Any]], + latents: Optional[ms.Tensor], + generator: Optional[Union[np.random.Generator, List[np.random.Generator]]], + output_type: str, + output_uncertainty: bool, + ) -> int: + if num_inference_steps is None: + raise ValueError("`num_inference_steps` is not specified and could not be resolved from the model config.") + if num_inference_steps < 1: + raise ValueError("`num_inference_steps` must be positive.") + if ensemble_size < 1: + raise ValueError("`ensemble_size` must be positive.") + if ensemble_size == 2: + logger.warning( + "`ensemble_size` == 2 results are similar to no ensembling (1); " + "consider increasing the value to at least 3." + ) + if ensemble_size > 1 and (self.scale_invariant or self.shift_invariant) and not is_scipy_available(): + raise ImportError("Make sure to install scipy if you want to use ensembling.") + if ensemble_size == 1 and output_uncertainty: + raise ValueError( + "Computing uncertainty by setting `output_uncertainty=True` also requires setting `ensemble_size` " + "greater than 1." + ) + if processing_resolution is None: + raise ValueError( + "`processing_resolution` is not specified and could not be resolved from the model config." + ) + if processing_resolution < 0: + raise ValueError( + "`processing_resolution` must be non-negative: 0 for native resolution, or any positive value for " + "downsampled processing." + ) + if processing_resolution % self.vae_scale_factor != 0: + raise ValueError(f"`processing_resolution` must be a multiple of {self.vae_scale_factor}.") + if resample_method_input not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"): + raise ValueError( + "`resample_method_input` takes string values compatible with PIL library: " + "nearest, nearest-exact, bilinear, bicubic, area." + ) + if resample_method_output not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"): + raise ValueError( + "`resample_method_output` takes string values compatible with PIL library: " + "nearest, nearest-exact, bilinear, bicubic, area." + ) + if batch_size < 1: + raise ValueError("`batch_size` must be positive.") + if output_type not in ["pt", "np"]: + raise ValueError("`output_type` must be one of `pt` or `np`.") + if latents is not None and generator is not None: + raise ValueError("`latents` and `generator` cannot be used together.") + if ensembling_kwargs is not None: + if not isinstance(ensembling_kwargs, dict): + raise ValueError("`ensembling_kwargs` must be a dictionary.") + if "reduction" in ensembling_kwargs and ensembling_kwargs["reduction"] not in ("mean", "median"): + raise ValueError("`ensembling_kwargs['reduction']` can be either `'mean'` or `'median'`.") + + # image checks + num_images = 0 + W, H = None, None + if not isinstance(image, list): + image = [image] + for i, img in enumerate(image): + if isinstance(img, np.ndarray) or ops.is_tensor(img): + if img.ndim not in (2, 3, 4): + raise ValueError(f"`image[{i}]` has unsupported dimensions or shape: {img.shape}.") + H_i, W_i = img.shape[-2:] + N_i = 1 + if img.ndim == 4: + N_i = img.shape[0] + elif isinstance(img, Image.Image): + W_i, H_i = img.size + N_i = 1 + else: + raise ValueError(f"Unsupported `image[{i}]` type: {type(img)}.") + if W is None: + W, H = W_i, H_i + elif (W, H) != (W_i, H_i): + raise ValueError( + f"Input `image[{i}]` has incompatible dimensions {(W_i, H_i)} with the previous images {(W, H)}" + ) + num_images += N_i + + # latents checks + if latents is not None: + if not ops.is_tensor(latents): + raise ValueError("`latents` must be a ms.Tensor.") + if latents.ndim != 4: + raise ValueError(f"`latents` has unsupported dimensions or shape: {latents.shape}.") + + if processing_resolution > 0: + max_orig = max(H, W) + new_H = H * processing_resolution // max_orig + new_W = W * processing_resolution // max_orig + if new_H == 0 or new_W == 0: + raise ValueError(f"Extreme aspect ratio of the input image: [{W} x {H}]") + W, H = new_W, new_H + w = (W + self.vae_scale_factor - 1) // self.vae_scale_factor + h = (H + self.vae_scale_factor - 1) // self.vae_scale_factor + shape_expected = (num_images * ensemble_size, self.vae.config.latent_channels, h, w) + + if latents.shape != shape_expected: + raise ValueError(f"`latents` has unexpected shape={latents.shape} expected={shape_expected}.") + + # generator checks + if generator is not None: + if isinstance(generator, list): + if len(generator) != num_images * ensemble_size: + raise ValueError( + "The number of generators must match the total number of ensemble members for all input images." + ) + elif not isinstance(generator, np.random.Generator): + raise ValueError(f"Unsupported generator type: {type(generator)}.") + + return num_images + + def progress_bar(self, iterable=None, total=None, desc=None, leave=True): + if not hasattr(self, "_progress_bar_config"): + self._progress_bar_config = {} + elif not isinstance(self._progress_bar_config, dict): + raise ValueError( + f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." + ) + + progress_bar_config = dict(**self._progress_bar_config) + progress_bar_config["desc"] = progress_bar_config.get("desc", desc) + progress_bar_config["leave"] = progress_bar_config.get("leave", leave) + if iterable is not None: + return tqdm(iterable, **progress_bar_config) + elif total is not None: + return tqdm(total=total, **progress_bar_config) + else: + raise ValueError("Either `total` or `iterable` has to be defined.") + + def __call__( + self, + image: PipelineImageInput, + num_inference_steps: Optional[int] = None, + ensemble_size: int = 1, + processing_resolution: Optional[int] = None, + match_input_resolution: bool = True, + resample_method_input: str = "bilinear", + resample_method_output: str = "bilinear", + batch_size: int = 1, + ensembling_kwargs: Optional[Dict[str, Any]] = None, + latents: Optional[Union[ms.Tensor, List[ms.Tensor]]] = None, + generator: Optional[Union[np.random.Generator, List[np.random.Generator]]] = None, + output_type: str = "np", + output_uncertainty: bool = False, + output_latent: bool = False, + return_dict: bool = False, + ): + """ + Function invoked when calling the pipeline. + + Args: + image (`PIL.Image.Image`, `np.ndarray`, `ms.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`), + `List[ms.Tensor]`: An input image or images used as an input for the depth estimation task. For + arrays and tensors, the expected value range is between `[0, 1]`. Passing a batch of images is possible + by providing a four-dimensional array or a tensor. Additionally, a list of images of two- or + three-dimensional arrays or tensors can be passed. In the latter case, all list elements must have the + same width and height. + num_inference_steps (`int`, *optional*, defaults to `None`): + Number of denoising diffusion steps during inference. The default value `None` results in automatic + selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4 + for Marigold-LCM models. + ensemble_size (`int`, defaults to `1`): + Number of ensemble predictions. Recommended values are 5 and higher for better precision, or 1 for + faster inference. + processing_resolution (`int`, *optional*, defaults to `None`): + Effective processing resolution. When set to `0`, matches the larger input image dimension. This + produces crisper predictions, but may also lead to the overall loss of global context. The default + value `None` resolves to the optimal value from the model config. + match_input_resolution (`bool`, *optional*, defaults to `True`): + When enabled, the output prediction is resized to match the input dimensions. When disabled, the longer + side of the output will equal to `processing_resolution`. + resample_method_input (`str`, *optional*, defaults to `"bilinear"`): + Resampling method used to resize input images to `processing_resolution`. The accepted values are: + `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`. + resample_method_output (`str`, *optional*, defaults to `"bilinear"`): + Resampling method used to resize output predictions to match the input resolution. The accepted values + are `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`. + batch_size (`int`, *optional*, defaults to `1`): + Batch size; only matters when setting `ensemble_size` or passing a tensor of images. + ensembling_kwargs (`dict`, *optional*, defaults to `None`) + Extra dictionary with arguments for precise ensembling control. The following options are available: + - reduction (`str`, *optional*, defaults to `"median"`): Defines the ensembling function applied in + every pixel location, can be either `"median"` or `"mean"`. + - regularizer_strength (`float`, *optional*, defaults to `0.02`): Strength of the regularizer that + pulls the aligned predictions to the unit range from 0 to 1. + - max_iter (`int`, *optional*, defaults to `2`): Maximum number of the alignment solver steps. Refer to + `scipy.optimize.minimize` function, `options` argument. + - tol (`float`, *optional*, defaults to `1e-3`): Alignment solver tolerance. The solver stops when the + tolerance is reached. + - max_res (`int`, *optional*, defaults to `None`): Resolution at which the alignment is performed; + `None` matches the `processing_resolution`. + latents (`ms.Tensor`, or `List[ms.Tensor]`, *optional*, defaults to `None`): + Latent noise tensors to replace the random initialization. These can be taken from the previous + function call's output. + generator (`np.random.Generator`, or `List[np.random.Generator]`, *optional*, defaults to `None`): + Random number generator object to ensure reproducibility. + output_type (`str`, *optional*, defaults to `"np"`): + Preferred format of the output's `prediction` and the optional `uncertainty` fields. The accepted + values are: `"np"` (numpy array) or `"pt"` (torch tensor). + output_uncertainty (`bool`, *optional*, defaults to `False`): + When enabled, the output's `uncertainty` field contains the predictive uncertainty map, provided that + the `ensemble_size` argument is set to a value above 2. + output_latent (`bool`, *optional*, defaults to `False`): + When enabled, the output's `latent` field contains the latent codes corresponding to the predictions + within the ensemble. These codes can be saved, modified, and used for subsequent calls with the + `latents` argument. + return_dict (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~pipelines.marigold.MarigoldDepthOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.marigold.MarigoldDepthOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.marigold.MarigoldDepthOutput`] is returned, otherwise a + `tuple` is returned where the first element is the prediction, the second element is the uncertainty + (or `None`), and the third is the latent (or `None`). + """ + + # 0. Resolving variables. + dtype = self.dtype + + # Model-specific optimal default values leading to fast and reasonable results. + if num_inference_steps is None: + num_inference_steps = self.default_denoising_steps + if processing_resolution is None: + processing_resolution = self.default_processing_resolution + + # 1. Check inputs. + num_images = self.check_inputs( + image, + num_inference_steps, + ensemble_size, + processing_resolution, + resample_method_input, + resample_method_output, + batch_size, + ensembling_kwargs, + latents, + generator, + output_type, + output_uncertainty, + ) + + # 2. Prepare empty text conditioning. + # Model invocation: self.tokenizer, self.text_encoder. + if self.empty_text_embedding is None: + prompt = "" + text_inputs = self.tokenizer( + prompt, + padding="do_not_pad", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_input_ids = ms.Tensor.from_numpy(text_inputs.input_ids) + self.empty_text_embedding = self.text_encoder(text_input_ids)[0] # [1,2,1024] + + # 3. Preprocess input images. This function loads input image or images of compatible dimensions `(H, W)`, + # optionally downsamples them to the `processing_resolution` `(PH, PW)`, where + # `max(PH, PW) == processing_resolution`, and pads the dimensions to `(PPH, PPW)` such that these values are + # divisible by the latent space downscaling factor (typically 8 in Stable Diffusion). The default value `None` + # of `processing_resolution` resolves to the optimal value from the model config. It is a recommended mode of + # operation and leads to the most reasonable results. Using the native image resolution or any other processing + # resolution can lead to loss of either fine details or global context in the output predictions. + image, padding, original_resolution = self.image_processor.preprocess( + image, processing_resolution, resample_method_input, dtype + ) # [N,3,PPH,PPW] + + # 4. Encode input image into latent space. At this step, each of the `N` input images is represented with `E` + # ensemble members. Each ensemble member is an independent diffused prediction, just initialized independently. + # Latents of each such predictions across all input images and all ensemble members are represented in the + # `pred_latent` variable. The variable `image_latent` is of the same shape: it contains each input image encoded + # into latent space and replicated `E` times. The latents can be either generated (see `generator` to ensure + # reproducibility), or passed explicitly via the `latents` argument. The latter can be set outside the pipeline + # code. For example, in the Marigold-LCM video processing demo, the latents initialization of a frame is taken + # as a convex combination of the latents output of the pipeline for the previous frame and a newly-sampled + # noise. This behavior can be achieved by setting the `output_latent` argument to `True`. The latent space + # dimensions are `(h, w)`. Encoding into latent space happens in batches of size `batch_size`. + # Model invocation: self.vae.encoder. + image_latent, pred_latent = self.prepare_latents( + image, latents, generator, ensemble_size, batch_size + ) # [N*E,4,h,w], [N*E,4,h,w] + + del image + + batch_empty_text_embedding = self.empty_text_embedding.to(dtype=dtype).tile((batch_size, 1, 1)) # [B,1024,2] + + # 5. Process the denoising loop. All `N * E` latents are processed sequentially in batches of size `batch_size`. + # The unet model takes concatenated latent spaces of the input image and the predicted modality as an input, and + # outputs noise for the predicted modality's latent space. The number of denoising diffusion steps is defined by + # `num_inference_steps`. It is either set directly, or resolves to the optimal value specific to the loaded + # model. + # Model invocation: self.unet. + pred_latents = [] + + for i in self.progress_bar( + range(0, num_images * ensemble_size, batch_size), leave=True, desc="Marigold predictions..." + ): + batch_image_latent = image_latent[i : i + batch_size] # [B,4,h,w] + batch_pred_latent = pred_latent[i : i + batch_size] # [B,4,h,w] + effective_batch_size = batch_image_latent.shape[0] + text = batch_empty_text_embedding[:effective_batch_size] # [B,2,1024] + + self.scheduler.set_timesteps(num_inference_steps) + for t in self.progress_bar(self.scheduler.timesteps, leave=False, desc="Diffusion steps..."): + batch_latent = ops.cat([batch_image_latent, batch_pred_latent], axis=1) # [B,8,h,w] + noise = self.unet(batch_latent, t, encoder_hidden_states=text, return_dict=False)[0] # [B,4,h,w] + batch_pred_latent = self.scheduler.step(noise, t, batch_pred_latent, generator=generator)[ + 0 + ] # [B,4,h,w] + + pred_latents.append(batch_pred_latent) + + pred_latent = ops.cat(pred_latents, axis=0) # [N*E,4,h,w] + + del ( + pred_latents, + image_latent, + batch_empty_text_embedding, + batch_image_latent, + batch_pred_latent, + text, + batch_latent, + noise, + ) + + # 6. Decode predictions from latent into pixel space. The resulting `N * E` predictions have shape `(PPH, PPW)`, + # which requires slight postprocessing. Decoding into pixel space happens in batches of size `batch_size`. + # Model invocation: self.vae.decoder. + prediction = ops.cat( + [ + self.decode_prediction(pred_latent[i : i + batch_size]) + for i in range(0, pred_latent.shape[0], batch_size) + ], + axis=0, + ) # [N*E,1,PPH,PPW] + + if not output_latent: + pred_latent = None + + # 7. Remove padding. The output shape is (PH, PW). + prediction = self.image_processor.unpad_image(prediction, padding) # [N*E,1,PH,PW] + + # 8. Ensemble and compute uncertainty (when `output_uncertainty` is set). This code treats each of the `N` + # groups of `E` ensemble predictions independently. For each group it computes an ensembled prediction of shape + # `(PH, PW)` and an optional uncertainty map of the same dimensions. After computing this pair of outputs for + # each group independently, it stacks them respectively into batches of `N` almost final predictions and + # uncertainty maps. + uncertainty = None + if ensemble_size > 1: + prediction = prediction.reshape(num_images, ensemble_size, *prediction.shape[1:]) # [N,E,1,PH,PW] + prediction = [ + self.ensemble_depth( + prediction[i], + self.scale_invariant, + self.shift_invariant, + output_uncertainty, + **(ensembling_kwargs or {}), + ) + for i in range(num_images) + ] # [ [[1,1,PH,PW], [1,1,PH,PW]], ... ] + prediction, uncertainty = zip(*prediction) # [[1,1,PH,PW], ... ], [[1,1,PH,PW], ... ] + prediction = ops.cat(prediction, axis=0) # [N,1,PH,PW] + if output_uncertainty: + uncertainty = ops.cat(uncertainty, axis=0) # [N,1,PH,PW] + else: + uncertainty = None + + # 9. If `match_input_resolution` is set, the output prediction and the uncertainty are upsampled to match the + # input resolution `(H, W)`. This step may introduce upsampling artifacts, and therefore can be disabled. + # Depending on the downstream use-case, upsampling can be also chosen based on the tolerated artifacts by + # setting the `resample_method_output` parameter (e.g., to `"nearest"`). + if match_input_resolution: + prediction = self.image_processor.resize_antialias( + prediction, original_resolution, resample_method_output, is_aa=False + ) # [N,1,H,W] + if uncertainty is not None and output_uncertainty: + uncertainty = self.image_processor.resize_antialias( + uncertainty, original_resolution, resample_method_output, is_aa=False + ) # [N,1,H,W] + + # 10. Prepare the final outputs. + if output_type == "np": + prediction = self.image_processor.pt_to_numpy(prediction) # [N,H,W,1] + if uncertainty is not None and output_uncertainty: + uncertainty = self.image_processor.pt_to_numpy(uncertainty) # [N,H,W,1] + + if not return_dict: + return (prediction, uncertainty, pred_latent) + + return MarigoldDepthOutput( + prediction=prediction, + uncertainty=uncertainty, + latent=pred_latent, + ) + + def prepare_latents( + self, + image: ms.Tensor, + latents: Optional[ms.Tensor], + generator: Optional[np.random.Generator], + ensemble_size: int, + batch_size: int, + ) -> Tuple[ms.Tensor, ms.Tensor]: + def retrieve_latents(encoder_output): + assert ops.is_tensor( + encoder_output + ), "Could not access latents of provided encoder_output which is not a tensor" + if hasattr(self.vae, "diag_gauss_dist"): + return self.vae.diag_gauss_dist.mode(encoder_output) + else: + return encoder_output + + image_latent = ops.cat( + [ + retrieve_latents(self.vae.encode(image[i : i + batch_size])[0]) + for i in range(0, image.shape[0], batch_size) + ], + axis=0, + ) # [N,4,h,w] + image_latent = image_latent * self.vae.config.scaling_factor + image_latent = image_latent.repeat_interleave(ensemble_size, dim=0) # [N*E,4,h,w] + + pred_latent = latents + if pred_latent is None: + pred_latent = randn_tensor( + image_latent.shape, + generator=generator, + dtype=image_latent.dtype, + ) # [N*E,4,h,w] + + return image_latent, pred_latent + + def decode_prediction(self, pred_latent: ms.Tensor) -> ms.Tensor: + if pred_latent.ndim != 4 or pred_latent.shape[1] != self.vae.config.latent_channels: + raise ValueError( + f"Expecting 4D tensor of shape [B,{self.vae.config.latent_channels},H,W]; got {pred_latent.shape}." + ) + + prediction = self.vae.decode(pred_latent / self.vae.config.scaling_factor, return_dict=False)[0] # [B,3,H,W] + + prediction = prediction.mean(axis=1, keep_dims=True) # [B,1,H,W] + prediction = ops.clip(prediction, -1.0, 1.0) # [B,1,H,W] + prediction = (prediction + 1.0) / 2.0 + + return prediction # [B,1,H,W] + + @staticmethod + def ensemble_depth( + depth: ms.Tensor, + scale_invariant: bool = True, + shift_invariant: bool = True, + output_uncertainty: bool = False, + reduction: str = "median", + regularizer_strength: float = 0.02, + max_iter: int = 2, + tol: float = 1e-3, + max_res: int = 1024, + ) -> Tuple[ms.Tensor, Optional[ms.Tensor]]: + """ + Ensembles the depth maps represented by the `depth` tensor with expected shape `(B, 1, H, W)`, where B is the + number of ensemble members for a given prediction of size `(H x W)`. Even though the function is designed for + depth maps, it can also be used with disparity maps as long as the input tensor values are non-negative. The + alignment happens when the predictions have one or more degrees of freedom, that is when they are either + affine-invariant (`scale_invariant=True` and `shift_invariant=True`), or just scale-invariant (only + `scale_invariant=True`). For absolute predictions (`scale_invariant=False` and `shift_invariant=False`) + alignment is skipped and only ensembling is performed. + + Args: + depth (`ms.Tensor`): + Input ensemble depth maps. + scale_invariant (`bool`, *optional*, defaults to `True`): + Whether to treat predictions as scale-invariant. + shift_invariant (`bool`, *optional*, defaults to `True`): + Whether to treat predictions as shift-invariant. + output_uncertainty (`bool`, *optional*, defaults to `False`): + Whether to output uncertainty map. + reduction (`str`, *optional*, defaults to `"median"`): + Reduction method used to ensemble aligned predictions. The accepted values are: `"mean"` and + `"median"`. + regularizer_strength (`float`, *optional*, defaults to `0.02`): + Strength of the regularizer that pulls the aligned predictions to the unit range from 0 to 1. + max_iter (`int`, *optional*, defaults to `2`): + Maximum number of the alignment solver steps. Refer to `scipy.optimize.minimize` function, `options` + argument. + tol (`float`, *optional*, defaults to `1e-3`): + Alignment solver tolerance. The solver stops when the tolerance is reached. + max_res (`int`, *optional*, defaults to `1024`): + Resolution at which the alignment is performed; `None` matches the `processing_resolution`. + Returns: + A tensor of aligned and ensembled depth maps and optionally a tensor of uncertainties of the same shape: + `(1, 1, H, W)`. + """ + if depth.ndim != 4 or depth.shape[1] != 1: + raise ValueError(f"Expecting 4D tensor of shape [B,1,H,W]; got {depth.shape}.") + if reduction not in ("mean", "median"): + raise ValueError(f"Unrecognized reduction method: {reduction}.") + if not scale_invariant and shift_invariant: + raise ValueError("Pure shift-invariant ensembling is not supported.") + + def init_param(depth: ms.Tensor): + init_min = depth.reshape(ensemble_size, -1).min(axis=1) + init_max = depth.reshape(ensemble_size, -1).max(axis=1) + + if scale_invariant and shift_invariant: + init_s = 1.0 / (init_max - init_min).clamp(min=1e-6) + init_t = -init_s * init_min + param = ops.cat((init_s, init_t)).numpy() + elif scale_invariant: + init_s = 1.0 / init_max.clamp(min=1e-6) + param = init_s.numpy() + else: + raise ValueError("Unrecognized alignment.") + + return param + + def align(depth: ms.Tensor, param: np.ndarray) -> ms.Tensor: + if scale_invariant and shift_invariant: + s, t = np.split(param, 2) + s = ms.Tensor.from_numpy(s).to(depth.dtype).view(ensemble_size, 1, 1, 1) + t = ms.Tensor.from_numpy(t).to(depth.dtype).view(ensemble_size, 1, 1, 1) + out = depth * s + t + elif scale_invariant: + s = ms.Tensor.from_numpy(param).to(depth.dtype).view(ensemble_size, 1, 1, 1) + out = depth * s + else: + raise ValueError("Unrecognized alignment.") + return out + + def ensemble( + depth_aligned: ms.Tensor, return_uncertainty: bool = False + ) -> Tuple[ms.Tensor, Optional[ms.Tensor]]: + uncertainty = None + if reduction == "mean": + prediction = ops.mean(depth_aligned, axis=0, keep_dims=True) + if return_uncertainty: + uncertainty = ops.std(depth_aligned, axis=0, keepdims=True) + elif reduction == "median": + # ops.median has two return values and does not supported some data-type + prediction = ops.median(depth_aligned.float(), axis=0, keepdims=True)[0] + prediction = prediction.to(depth_aligned.dtype) + if return_uncertainty: + uncertainty = ops.median(ops.abs(depth_aligned - prediction).float(), axis=0, keepdims=True)[0] + uncertainty = uncertainty.to(depth_aligned.dtype) + else: + raise ValueError(f"Unrecognized reduction method: {reduction}.") + return prediction, uncertainty + + def cost_fn(param: np.ndarray, depth: ms.Tensor) -> float: + cost = 0.0 + depth_aligned = align(depth, param) + + for i, j in ops.combinations(ops.arange(ensemble_size)): + diff = depth_aligned[i] - depth_aligned[j] + cost += (diff**2).mean().sqrt().item() + + if regularizer_strength > 0: + prediction, _ = ensemble(depth_aligned, return_uncertainty=False) + err_near = (0.0 - prediction.min()).abs().item() + err_far = (1.0 - prediction.max()).abs().item() + cost += (err_near + err_far) * regularizer_strength + + return cost + + def compute_param(depth: ms.Tensor): + import scipy + + depth_to_align = depth.to(ms.float32) + if max_res is not None and max(depth_to_align.shape[2:]) > max_res: + depth_to_align = MarigoldImageProcessor.resize_to_max_edge(depth_to_align, max_res, "nearest-exact") + + param = init_param(depth_to_align) + + res = scipy.optimize.minimize( + partial(cost_fn, depth=depth_to_align), + param, + method="BFGS", + tol=tol, + options={"maxiter": max_iter, "disp": False}, + ) + + return res.x + + requires_aligning = scale_invariant or shift_invariant + ensemble_size = depth.shape[0] + + if requires_aligning: + param = compute_param(depth) + depth = align(depth, param) + + depth, uncertainty = ensemble(depth, return_uncertainty=output_uncertainty) + + depth_max = depth.max() + if scale_invariant and shift_invariant: + depth_min = depth.min() + elif scale_invariant: + depth_min = 0 + else: + raise ValueError("Unrecognized alignment.") + depth_range = (depth_max - depth_min).clamp(min=1e-6) + depth = (depth - depth_min) / depth_range + if output_uncertainty: + uncertainty /= depth_range + + return depth, uncertainty # [1,1,H,W], [1,1,H,W] diff --git a/mindone/diffusers/pipelines/marigold/pipeline_marigold_normals.py b/mindone/diffusers/pipelines/marigold/pipeline_marigold_normals.py new file mode 100644 index 0000000000..f40ada708a --- /dev/null +++ b/mindone/diffusers/pipelines/marigold/pipeline_marigold_normals.py @@ -0,0 +1,672 @@ +# Copyright 2024 Marigold authors, PRS ETH Zurich. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# More information and citation instructions are available on the +# Marigold project website: https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +from PIL import Image +from tqdm.auto import tqdm +from transformers import CLIPTokenizer + +import mindspore as ms +from mindspore import ops + +from ....transformers import CLIPTextModel +from ...image_processor import PipelineImageInput +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import DDIMScheduler, LCMScheduler +from ...utils import BaseOutput, logging +from ...utils.mindspore_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .marigold_image_processing import MarigoldImageProcessor + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ +Examples: +```py +>>> from mindone import diffusers +>>> import mindspore + +>>> pipe = diffusers.MarigoldNormalsPipeline.from_pretrained( +... "prs-eth/marigold-normals-lcm-v0-1", variant="fp16", mindspore_dtype=mindspore.float16 +... ) + +>>> image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg") +>>> normals = pipe(image) + +>>> vis = pipe.image_processor.visualize_normals(normals[0]) +>>> vis[0].save("einstein_normals.png") +``` +""" + + +@dataclass +class MarigoldNormalsOutput(BaseOutput): + """ + Output class for Marigold monocular normals prediction pipeline. + + Args: + prediction (`np.ndarray`, `ms.Tensor`): + Predicted normals with values in the range [-1, 1]. The shape is always $numimages \times 3 \times height + \times width$, regardless of whether the images were passed as a 4D array or a list. + uncertainty (`None`, `np.ndarray`, `ms.Tensor`): + Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is $numimages + \times 1 \times height \times width$. + latent (`None`, `ms.Tensor`): + Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline. + The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$. + """ + + prediction: Union[np.ndarray, ms.Tensor] + uncertainty: Union[None, np.ndarray, ms.Tensor] + latent: Union[None, ms.Tensor] + + +class MarigoldNormalsPipeline(DiffusionPipeline): + """ + Pipeline for monocular normals estimation using the Marigold method: https://marigoldmonodepth.github.io. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + unet (`UNet2DConditionModel`): + Conditional U-Net to denoise the normals latent, conditioned on image latent. + vae (`AutoencoderKL`): + Variational Auto-Encoder (VAE) Model to encode and decode images and predictions to and from latent + representations. + scheduler (`DDIMScheduler` or `LCMScheduler`): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + text_encoder (`CLIPTextModel`): + Text-encoder, for empty text embedding. + tokenizer (`CLIPTokenizer`): + CLIP tokenizer. + prediction_type (`str`, *optional*): + Type of predictions made by the model. + use_full_z_range (`bool`, *optional*): + Whether the normals predicted by this model utilize the full range of the Z dimension, or only its positive + half. + default_denoising_steps (`int`, *optional*): + The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable + quality with the given model. This value must be set in the model config. When the pipeline is called + without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure + reasonable results with various model flavors compatible with the pipeline, such as those relying on very + short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`). + default_processing_resolution (`int`, *optional*): + The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in + the model config. When the pipeline is called without explicitly setting `processing_resolution`, the + default value is used. This is required to ensure reasonable results with various model flavors trained + with varying optimal processing resolution values. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + supported_prediction_types = ("normals",) + + def __init__( + self, + unet: UNet2DConditionModel, + vae: AutoencoderKL, + scheduler: Union[DDIMScheduler, LCMScheduler], + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + prediction_type: Optional[str] = None, + use_full_z_range: Optional[bool] = True, + default_denoising_steps: Optional[int] = None, + default_processing_resolution: Optional[int] = None, + ): + super().__init__() + + if prediction_type not in self.supported_prediction_types: + logger.warning( + f"Potentially unsupported `prediction_type='{prediction_type}'`; values supported by the pipeline: " + f"{self.supported_prediction_types}." + ) + + self.register_modules( + unet=unet, + vae=vae, + scheduler=scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + ) + self.register_to_config( + use_full_z_range=use_full_z_range, + default_denoising_steps=default_denoising_steps, + default_processing_resolution=default_processing_resolution, + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + self.use_full_z_range = use_full_z_range + self.default_denoising_steps = default_denoising_steps + self.default_processing_resolution = default_processing_resolution + + self.empty_text_embedding = None + + self.image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def check_inputs( + self, + image: PipelineImageInput, + num_inference_steps: int, + ensemble_size: int, + processing_resolution: int, + resample_method_input: str, + resample_method_output: str, + batch_size: int, + ensembling_kwargs: Optional[Dict[str, Any]], + latents: Optional[ms.Tensor], + generator: Optional[Union[np.random.Generator, List[np.random.Generator]]], + output_type: str, + output_uncertainty: bool, + ) -> int: + if num_inference_steps is None: + raise ValueError("`num_inference_steps` is not specified and could not be resolved from the model config.") + if num_inference_steps < 1: + raise ValueError("`num_inference_steps` must be positive.") + if ensemble_size < 1: + raise ValueError("`ensemble_size` must be positive.") + if ensemble_size == 2: + logger.warning( + "`ensemble_size` == 2 results are similar to no ensembling (1); " + "consider increasing the value to at least 3." + ) + if ensemble_size == 1 and output_uncertainty: + raise ValueError( + "Computing uncertainty by setting `output_uncertainty=True` also requires setting `ensemble_size` " + "greater than 1." + ) + if processing_resolution is None: + raise ValueError( + "`processing_resolution` is not specified and could not be resolved from the model config." + ) + if processing_resolution < 0: + raise ValueError( + "`processing_resolution` must be non-negative: 0 for native resolution, or any positive value for " + "downsampled processing." + ) + if processing_resolution % self.vae_scale_factor != 0: + raise ValueError(f"`processing_resolution` must be a multiple of {self.vae_scale_factor}.") + if resample_method_input not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"): + raise ValueError( + "`resample_method_input` takes string values compatible with PIL library: " + "nearest, nearest-exact, bilinear, bicubic, area." + ) + if resample_method_output not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"): + raise ValueError( + "`resample_method_output` takes string values compatible with PIL library: " + "nearest, nearest-exact, bilinear, bicubic, area." + ) + if batch_size < 1: + raise ValueError("`batch_size` must be positive.") + if output_type not in ["pt", "np"]: + raise ValueError("`output_type` must be one of `pt` or `np`.") + if latents is not None and generator is not None: + raise ValueError("`latents` and `generator` cannot be used together.") + if ensembling_kwargs is not None: + if not isinstance(ensembling_kwargs, dict): + raise ValueError("`ensembling_kwargs` must be a dictionary.") + if "reduction" in ensembling_kwargs and ensembling_kwargs["reduction"] not in ("closest", "mean"): + raise ValueError("`ensembling_kwargs['reduction']` can be either `'closest'` or `'mean'`.") + + # image checks + num_images = 0 + W, H = None, None + if not isinstance(image, list): + image = [image] + for i, img in enumerate(image): + if isinstance(img, np.ndarray) or ops.is_tensor(img): + if img.ndim not in (2, 3, 4): + raise ValueError(f"`image[{i}]` has unsupported dimensions or shape: {img.shape}.") + H_i, W_i = img.shape[-2:] + N_i = 1 + if img.ndim == 4: + N_i = img.shape[0] + elif isinstance(img, Image.Image): + W_i, H_i = img.size + N_i = 1 + else: + raise ValueError(f"Unsupported `image[{i}]` type: {type(img)}.") + if W is None: + W, H = W_i, H_i + elif (W, H) != (W_i, H_i): + raise ValueError( + f"Input `image[{i}]` has incompatible dimensions {(W_i, H_i)} with the previous images {(W, H)}" + ) + num_images += N_i + + # latents checks + if latents is not None: + if not ops.is_tensor(latents): + raise ValueError("`latents` must be a ms.Tensor.") + if latents.ndim != 4: + raise ValueError(f"`latents` has unsupported dimensions or shape: {latents.shape}.") + + if processing_resolution > 0: + max_orig = max(H, W) + new_H = H * processing_resolution // max_orig + new_W = W * processing_resolution // max_orig + if new_H == 0 or new_W == 0: + raise ValueError(f"Extreme aspect ratio of the input image: [{W} x {H}]") + W, H = new_W, new_H + w = (W + self.vae_scale_factor - 1) // self.vae_scale_factor + h = (H + self.vae_scale_factor - 1) // self.vae_scale_factor + shape_expected = (num_images * ensemble_size, self.vae.config.latent_channels, h, w) + + if latents.shape != shape_expected: + raise ValueError(f"`latents` has unexpected shape={latents.shape} expected={shape_expected}.") + + # generator checks + if generator is not None: + if isinstance(generator, list): + if len(generator) != num_images * ensemble_size: + raise ValueError( + "The number of generators must match the total number of ensemble members for all input images." + ) + elif not isinstance(generator, np.random.Generator): + raise ValueError(f"Unsupported generator type: {type(generator)}.") + + return num_images + + def progress_bar(self, iterable=None, total=None, desc=None, leave=True): + if not hasattr(self, "_progress_bar_config"): + self._progress_bar_config = {} + elif not isinstance(self._progress_bar_config, dict): + raise ValueError( + f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." + ) + + progress_bar_config = dict(**self._progress_bar_config) + progress_bar_config["desc"] = progress_bar_config.get("desc", desc) + progress_bar_config["leave"] = progress_bar_config.get("leave", leave) + if iterable is not None: + return tqdm(iterable, **progress_bar_config) + elif total is not None: + return tqdm(total=total, **progress_bar_config) + else: + raise ValueError("Either `total` or `iterable` has to be defined.") + + def __call__( + self, + image: PipelineImageInput, + num_inference_steps: Optional[int] = None, + ensemble_size: int = 1, + processing_resolution: Optional[int] = None, + match_input_resolution: bool = True, + resample_method_input: str = "bilinear", + resample_method_output: str = "bilinear", + batch_size: int = 1, + ensembling_kwargs: Optional[Dict[str, Any]] = None, + latents: Optional[Union[ms.Tensor, List[ms.Tensor]]] = None, + generator: Optional[Union[np.random.Generator, List[np.random.Generator]]] = None, + output_type: str = "np", + output_uncertainty: bool = False, + output_latent: bool = False, + return_dict: bool = False, + ): + """ + Function invoked when calling the pipeline. + + Args: + image (`PIL.Image.Image`, `np.ndarray`, `ms.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`), + `List[ms.Tensor]`: An input image or images used as an input for the normals estimation task. For + arrays and tensors, the expected value range is between `[0, 1]`. Passing a batch of images is possible + by providing a four-dimensional array or a tensor. Additionally, a list of images of two- or + three-dimensional arrays or tensors can be passed. In the latter case, all list elements must have the + same width and height. + num_inference_steps (`int`, *optional*, defaults to `None`): + Number of denoising diffusion steps during inference. The default value `None` results in automatic + selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4 + for Marigold-LCM models. + ensemble_size (`int`, defaults to `1`): + Number of ensemble predictions. Recommended values are 5 and higher for better precision, or 1 for + faster inference. + processing_resolution (`int`, *optional*, defaults to `None`): + Effective processing resolution. When set to `0`, matches the larger input image dimension. This + produces crisper predictions, but may also lead to the overall loss of global context. The default + value `None` resolves to the optimal value from the model config. + match_input_resolution (`bool`, *optional*, defaults to `True`): + When enabled, the output prediction is resized to match the input dimensions. When disabled, the longer + side of the output will equal to `processing_resolution`. + resample_method_input (`str`, *optional*, defaults to `"bilinear"`): + Resampling method used to resize input images to `processing_resolution`. The accepted values are: + `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`. + resample_method_output (`str`, *optional*, defaults to `"bilinear"`): + Resampling method used to resize output predictions to match the input resolution. The accepted values + are `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`. + batch_size (`int`, *optional*, defaults to `1`): + Batch size; only matters when setting `ensemble_size` or passing a tensor of images. + ensembling_kwargs (`dict`, *optional*, defaults to `None`) + Extra dictionary with arguments for precise ensembling control. The following options are available: + - reduction (`str`, *optional*, defaults to `"closest"`): Defines the ensembling function applied in + every pixel location, can be either `"closest"` or `"mean"`. + latents (`ms.Tensor`, *optional*, defaults to `None`): + Latent noise tensors to replace the random initialization. These can be taken from the previous + function call's output. + generator (`np.random.Generator`, or `List[np.random.Generator]`, *optional*, defaults to `None`): + Random number generator object to ensure reproducibility. + output_type (`str`, *optional*, defaults to `"np"`): + Preferred format of the output's `prediction` and the optional `uncertainty` fields. The accepted + values are: `"np"` (numpy array) or `"pt"` (torch tensor). + output_uncertainty (`bool`, *optional*, defaults to `False`): + When enabled, the output's `uncertainty` field contains the predictive uncertainty map, provided that + the `ensemble_size` argument is set to a value above 2. + output_latent (`bool`, *optional*, defaults to `False`): + When enabled, the output's `latent` field contains the latent codes corresponding to the predictions + within the ensemble. These codes can be saved, modified, and used for subsequent calls with the + `latents` argument. + return_dict (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~pipelines.marigold.MarigoldDepthOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.marigold.MarigoldNormalsOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.marigold.MarigoldNormalsOutput`] is returned, otherwise a + `tuple` is returned where the first element is the prediction, the second element is the uncertainty + (or `None`), and the third is the latent (or `None`). + """ + + # 0. Resolving variables. + dtype = self.dtype + + # Model-specific optimal default values leading to fast and reasonable results. + if num_inference_steps is None: + num_inference_steps = self.default_denoising_steps + if processing_resolution is None: + processing_resolution = self.default_processing_resolution + + # 1. Check inputs. + num_images = self.check_inputs( + image, + num_inference_steps, + ensemble_size, + processing_resolution, + resample_method_input, + resample_method_output, + batch_size, + ensembling_kwargs, + latents, + generator, + output_type, + output_uncertainty, + ) + + # 2. Prepare empty text conditioning. + # Model invocation: self.tokenizer, self.text_encoder. + if self.empty_text_embedding is None: + prompt = "" + text_inputs = self.tokenizer( + prompt, + padding="do_not_pad", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_input_ids = ms.Tensor.from_numpy(text_inputs.input_ids) + self.empty_text_embedding = self.text_encoder(text_input_ids)[0] # [1,2,1024] + + # 3. Preprocess input images. This function loads input image or images of compatible dimensions `(H, W)`, + # optionally downsamples them to the `processing_resolution` `(PH, PW)`, where + # `max(PH, PW) == processing_resolution`, and pads the dimensions to `(PPH, PPW)` such that these values are + # divisible by the latent space downscaling factor (typically 8 in Stable Diffusion). The default value `None` + # of `processing_resolution` resolves to the optimal value from the model config. It is a recommended mode of + # operation and leads to the most reasonable results. Using the native image resolution or any other processing + # resolution can lead to loss of either fine details or global context in the output predictions. + image, padding, original_resolution = self.image_processor.preprocess( + image, processing_resolution, resample_method_input, dtype + ) # [N,3,PPH,PPW] + + # 4. Encode input image into latent space. At this step, each of the `N` input images is represented with `E` + # ensemble members. Each ensemble member is an independent diffused prediction, just initialized independently. + # Latents of each such predictions across all input images and all ensemble members are represented in the + # `pred_latent` variable. The variable `image_latent` is of the same shape: it contains each input image encoded + # into latent space and replicated `E` times. The latents can be either generated (see `generator` to ensure + # reproducibility), or passed explicitly via the `latents` argument. The latter can be set outside the pipeline + # code. For example, in the Marigold-LCM video processing demo, the latents initialization of a frame is taken + # as a convex combination of the latents output of the pipeline for the previous frame and a newly-sampled + # noise. This behavior can be achieved by setting the `output_latent` argument to `True`. The latent space + # dimensions are `(h, w)`. Encoding into latent space happens in batches of size `batch_size`. + # Model invocation: self.vae.encoder. + image_latent, pred_latent = self.prepare_latents( + image, latents, generator, ensemble_size, batch_size + ) # [N*E,4,h,w], [N*E,4,h,w] + + del image + + batch_empty_text_embedding = self.empty_text_embedding.to(dtype=dtype).tile((batch_size, 1, 1)) # [B,1024,2] + + # 5. Process the denoising loop. All `N * E` latents are processed sequentially in batches of size `batch_size`. + # The unet model takes concatenated latent spaces of the input image and the predicted modality as an input, and + # outputs noise for the predicted modality's latent space. The number of denoising diffusion steps is defined by + # `num_inference_steps`. It is either set directly, or resolves to the optimal value specific to the loaded + # model. + # Model invocation: self.unet. + pred_latents = [] + + for i in self.progress_bar( + range(0, num_images * ensemble_size, batch_size), leave=True, desc="Marigold predictions..." + ): + batch_image_latent = image_latent[i : i + batch_size] # [B,4,h,w] + batch_pred_latent = pred_latent[i : i + batch_size] # [B,4,h,w] + effective_batch_size = batch_image_latent.shape[0] + text = batch_empty_text_embedding[:effective_batch_size] # [B,2,1024] + + self.scheduler.set_timesteps(num_inference_steps) + for t in self.progress_bar(self.scheduler.timesteps, leave=False, desc="Diffusion steps..."): + batch_latent = ops.cat([batch_image_latent, batch_pred_latent], axis=1) # [B,8,h,w] + noise = self.unet(batch_latent, t, encoder_hidden_states=text, return_dict=False)[0] # [B,4,h,w] + batch_pred_latent = self.scheduler.step(noise, t, batch_pred_latent, generator=generator)[ + 0 + ] # [B,4,h,w] + + pred_latents.append(batch_pred_latent) + + pred_latent = ops.cat(pred_latents, axis=0) # [N*E,4,h,w] + + del ( + pred_latents, + image_latent, + batch_empty_text_embedding, + batch_image_latent, + batch_pred_latent, + text, + batch_latent, + noise, + ) + + # 6. Decode predictions from latent into pixel space. The resulting `N * E` predictions have shape `(PPH, PPW)`, + # which requires slight postprocessing. Decoding into pixel space happens in batches of size `batch_size`. + # Model invocation: self.vae.decoder. + prediction = ops.cat( + [ + self.decode_prediction(pred_latent[i : i + batch_size]) + for i in range(0, pred_latent.shape[0], batch_size) + ], + axis=0, + ) # [N*E,3,PPH,PPW] + + if not output_latent: + pred_latent = None + + # 7. Remove padding. The output shape is (PH, PW). + prediction = self.image_processor.unpad_image(prediction, padding) # [N*E,3,PH,PW] + + # 8. Ensemble and compute uncertainty (when `output_uncertainty` is set). This code treats each of the `N` + # groups of `E` ensemble predictions independently. For each group it computes an ensembled prediction of shape + # `(PH, PW)` and an optional uncertainty map of the same dimensions. After computing this pair of outputs for + # each group independently, it stacks them respectively into batches of `N` almost final predictions and + # uncertainty maps. + uncertainty = None + if ensemble_size > 1: + prediction = prediction.reshape(num_images, ensemble_size, *prediction.shape[1:]) # [N,E,3,PH,PW] + prediction = [ + self.ensemble_normals(prediction[i], output_uncertainty, **(ensembling_kwargs or {})) + for i in range(num_images) + ] # [ [[1,3,PH,PW], [1,1,PH,PW]], ... ] + prediction, uncertainty = zip(*prediction) # [[1,3,PH,PW], ... ], [[1,1,PH,PW], ... ] + prediction = ops.cat(prediction, axis=0) # [N,3,PH,PW] + if output_uncertainty: + uncertainty = ops.cat(uncertainty, axis=0) # [N,1,PH,PW] + else: + uncertainty = None + + # 9. If `match_input_resolution` is set, the output prediction and the uncertainty are upsampled to match the + # input resolution `(H, W)`. This step may introduce upsampling artifacts, and therefore can be disabled. + # After upsampling, the native resolution normal maps are renormalized to unit length to reduce the artifacts. + # Depending on the downstream use-case, upsampling can be also chosen based on the tolerated artifacts by + # setting the `resample_method_output` parameter (e.g., to `"nearest"`). + if match_input_resolution: + prediction = self.image_processor.resize_antialias( + prediction, original_resolution, resample_method_output, is_aa=False + ) # [N,3,H,W] + prediction = self.normalize_normals(prediction) # [N,3,H,W] + if uncertainty is not None and output_uncertainty: + uncertainty = self.image_processor.resize_antialias( + uncertainty, original_resolution, resample_method_output, is_aa=False + ) # [N,1,H,W] + + # 10. Prepare the final outputs. + if output_type == "np": + prediction = self.image_processor.pt_to_numpy(prediction) # [N,H,W,3] + if uncertainty is not None and output_uncertainty: + uncertainty = self.image_processor.pt_to_numpy(uncertainty) # [N,H,W,1] + + if not return_dict: + return (prediction, uncertainty, pred_latent) + + return MarigoldNormalsOutput( + prediction=prediction, + uncertainty=uncertainty, + latent=pred_latent, + ) + + # Copied from diffusers.pipelines.marigold.pipeline_marigold_depth.MarigoldDepthPipeline.prepare_latents + def prepare_latents( + self, + image: ms.Tensor, + latents: Optional[ms.Tensor], + generator: Optional[np.random.Generator], + ensemble_size: int, + batch_size: int, + ) -> Tuple[ms.Tensor, ms.Tensor]: + def retrieve_latents(encoder_output): + assert ops.is_tensor( + encoder_output + ), "Could not access latents of provided encoder_output which is not a tensor" + if hasattr(self.vae, "diag_gauss_dist"): + return self.vae.diag_gauss_dist.mode(encoder_output) + else: + return encoder_output + + image_latent = ops.cat( + [ + retrieve_latents(self.vae.encode(image[i : i + batch_size])[0]) + for i in range(0, image.shape[0], batch_size) + ], + axis=0, + ) # [N,4,h,w] + image_latent = image_latent * self.vae.config.scaling_factor + image_latent = image_latent.repeat_interleave(ensemble_size, dim=0) # [N*E,4,h,w] + + pred_latent = latents + if pred_latent is None: + pred_latent = randn_tensor( + image_latent.shape, + generator=generator, + dtype=image_latent.dtype, + ) # [N*E,4,h,w] + + return image_latent, pred_latent + + def decode_prediction(self, pred_latent: ms.Tensor) -> ms.Tensor: + if pred_latent.ndim != 4 or pred_latent.shape[1] != self.vae.config.latent_channels: + raise ValueError( + f"Expecting 4D tensor of shape [B,{self.vae.config.latent_channels},H,W]; got {pred_latent.shape}." + ) + + prediction = self.vae.decode(pred_latent / self.vae.config.scaling_factor, return_dict=False)[0] # [B,3,H,W] + + prediction = ops.clip(prediction, -1.0, 1.0) + + if not self.use_full_z_range: + prediction[:, 2, :, :] *= 0.5 + prediction[:, 2, :, :] += 0.5 + + prediction = self.normalize_normals(prediction) # [B,3,H,W] + + return prediction # [B,3,H,W] + + @staticmethod + def normalize_normals(normals: ms.Tensor, eps: float = 1e-6) -> ms.Tensor: + if normals.ndim != 4 or normals.shape[1] != 3: + raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.") + + norm = ops.norm(normals, dim=1, keepdim=True) + normals /= norm.clamp(min=eps) + + return normals + + @staticmethod + def ensemble_normals( + normals: ms.Tensor, output_uncertainty: bool, reduction: str = "closest" + ) -> Tuple[ms.Tensor, Optional[ms.Tensor]]: + """ + Ensembles the normals maps represented by the `normals` tensor with expected shape `(B, 3, H, W)`, where B is + the number of ensemble members for a given prediction of size `(H x W)`. + + Args: + normals (`ms.Tensor`): + Input ensemble normals maps. + output_uncertainty (`bool`, *optional*, defaults to `False`): + Whether to output uncertainty map. + reduction (`str`, *optional*, defaults to `"closest"`): + Reduction method used to ensemble aligned predictions. The accepted values are: `"closest"` and + `"mean"`. + + Returns: + A tensor of aligned and ensembled normals maps with shape `(1, 3, H, W)` and optionally a tensor of + uncertainties of shape `(1, 1, H, W)`. + """ + if normals.ndim != 4 or normals.shape[1] != 3: + raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.") + if reduction not in ("closest", "mean"): + raise ValueError(f"Unrecognized reduction method: {reduction}.") + + mean_normals = normals.mean(axis=0, keep_dims=True) # [1,3,H,W] + mean_normals = MarigoldNormalsPipeline.normalize_normals(mean_normals) # [1,3,H,W] + + sim_cos = (mean_normals * normals).sum(axis=1, keepdims=True) # [E,1,H,W] + sim_cos = sim_cos.clamp(-1.0, 1.0) # required to avoid NaN in uncertainty with fp16 + + uncertainty = None + if output_uncertainty: + uncertainty = sim_cos.arccos() # [E,1,H,W] + uncertainty = uncertainty.mean(axis=0, keep_dims=True) / ms.numpy.pi # [1,1,H,W] + + if reduction == "mean": + return mean_normals, uncertainty # [1,3,H,W], [1,1,H,W] + + closest_indices = sim_cos.argmax(axis=0, keepdims=True) # [1,1,H,W] + closest_indices = closest_indices.tile((1, 3, 1, 1)) # [1,3,H,W] + closest_normals = ops.gather_elements(normals, 0, closest_indices) # [1,3,H,W] + + return closest_normals, uncertainty # [1,3,H,W], [1,1,H,W] diff --git a/mindone/diffusers/pipelines/pipeline_loading_utils.py b/mindone/diffusers/pipelines/pipeline_loading_utils.py index 7a1a93b374..ee9675f351 100644 --- a/mindone/diffusers/pipelines/pipeline_loading_utils.py +++ b/mindone/diffusers/pipelines/pipeline_loading_utils.py @@ -302,6 +302,39 @@ def get_class_obj_and_candidates( return class_obj, class_candidates +def _get_custom_pipeline_class( + custom_pipeline, + repo_id=None, + hub_revision=None, + class_name=None, + cache_dir=None, + revision=None, +): + if custom_pipeline.endswith(".py"): + path = Path(custom_pipeline) + # decompose into folder & file + file_name = path.name + custom_pipeline = path.parent.absolute() + elif repo_id is not None: + file_name = f"{custom_pipeline}.py" + custom_pipeline = repo_id + else: + file_name = CUSTOM_PIPELINE_FILE_NAME + + if repo_id is not None and hub_revision is not None: + # if we load the pipeline code from the Hub + # make sure to overwrite the `revision` + revision = hub_revision + + return get_class_from_dynamic_module( + custom_pipeline, + module_file=file_name, + class_name=class_name, + cache_dir=cache_dir, + revision=revision, + ) + + def _get_pipeline_class( class_obj, config=None, @@ -314,25 +347,10 @@ def _get_pipeline_class( revision=None, ): if custom_pipeline is not None: - if custom_pipeline.endswith(".py"): - path = Path(custom_pipeline) - # decompose into folder & file - file_name = path.name - custom_pipeline = path.parent.absolute() - elif repo_id is not None: - file_name = f"{custom_pipeline}.py" - custom_pipeline = repo_id - else: - file_name = CUSTOM_PIPELINE_FILE_NAME - - if repo_id is not None and hub_revision is not None: - # if we load the pipeline code from the Hub - # make sure to overwrite the `revision` - revision = hub_revision - - return get_class_from_dynamic_module( + return _get_custom_pipeline_class( custom_pipeline, - module_file=file_name, + repo_id=repo_id, + hub_revision=hub_revision, class_name=class_name, cache_dir=cache_dir, revision=revision, @@ -372,7 +390,9 @@ def load_sub_model( cached_folder: Union[str, os.PathLike], ): """Helper method to load the module `name` from `library_name` and `class_name`""" + # retrieve class candidates + class_obj, class_candidates = get_class_obj_and_candidates( library_name, class_name, diff --git a/mindone/diffusers/pipelines/pipeline_utils.py b/mindone/diffusers/pipelines/pipeline_utils.py index 205531bdf9..b1d414b6a9 100644 --- a/mindone/diffusers/pipelines/pipeline_utils.py +++ b/mindone/diffusers/pipelines/pipeline_utils.py @@ -19,7 +19,7 @@ import re from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin import numpy as np import PIL.Image @@ -35,6 +35,7 @@ from .. import __version__ from ..configuration_utils import ConfigMixin +from ..models.modeling_utils import ModelMixin from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from ..utils import ( CONFIG_NAME, @@ -53,6 +54,7 @@ CUSTOM_PIPELINE_FILE_NAME, LOADABLE_CLASSES, _fetch_class_library_tuple, + _get_custom_pipeline_class, _get_pipeline_class, is_safetensors_compatible, load_sub_model, @@ -362,9 +364,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory where a downloaded pretrained model configuration is cached if the standard cache is not used. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to resume downloading the model weights and configuration files. If set to `False`, any - incompletely downloaded files are deleted. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1 + of Diffusers. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. @@ -381,7 +383,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P allowed by Git. custom_revision (`str`, *optional*): The specific model version to use. It can be a branch name, a tag name, or a commit id similar to - `revision` when loading a custom pipeline from the Hub. Defaults to the latest stable 🤗 Diffusers version. + `revision` when loading a custom pipeline from the Hub. Defaults to the latest stable 🤗 Diffusers + version. mirror (`str`, *optional*): Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not guarantee the timeliness or safety of the source, and you should refer to the mirror site for more @@ -431,7 +434,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P ``` """ cache_dir = kwargs.pop("cache_dir", None) - resume_download = kwargs.pop("resume_download", False) + resume_download = kwargs.pop("resume_download", None) force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", None) @@ -550,17 +553,18 @@ def load_module(name, value): # import it here to avoid circular import from mindone.diffusers import pipelines - # 6. Load each module in the pipeline + # 6. device map delegation which is not supported in MindSpore + # 7. Load each module in the pipeline for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."): - # 6.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names + # 7.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names class_name = class_name[4:] if class_name.startswith("Flax") else class_name - # 6.2 Define all importable classes + # 7.2 Define all importable classes is_pipeline_module = hasattr(pipelines, library_name) importable_classes = ALL_IMPORTABLE_CLASSES loaded_sub_model = None - # 6.3 Use passed sub model or load class_name from library_name + # 7.3 Use passed sub model or load class_name from library_name if name in passed_class_obj: # if the model is in a pipeline module, then we load it from the pipeline # check that passed_class_obj has correct parent class @@ -633,7 +637,7 @@ def get_connected_passed_kwargs(prefix): {"_".join([prefix, name]): component for name, component in connected_pipe.components.items()} ) - # 7. Potentially add passed objects if expected + # 8. Potentially add passed objects if expected missing_modules = set(expected_modules) - set(init_kwargs.keys()) passed_modules = list(passed_class_obj.keys()) optional_modules = pipeline_class._optional_components @@ -646,10 +650,10 @@ def get_connected_passed_kwargs(prefix): f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed." ) - # 8. Instantiate the pipeline + # 9. Instantiate the pipeline model = pipeline_class(**init_kwargs) - # 9. Save where the model was instantiated from + # 10. Save where the model was instantiated from model.register_to_config(_name_or_path=pretrained_model_name_or_path) return model @@ -695,9 +699,9 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to resume downloading the model weights and configuration files. If set to `False`, any - incompletely downloaded files are deleted. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1 + of Diffusers. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. @@ -750,7 +754,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: """ cache_dir = kwargs.pop("cache_dir", None) - resume_download = kwargs.pop("resume_download", False) + resume_download = kwargs.pop("resume_download", None) force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", None) @@ -954,7 +958,6 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: # Don't download index files of forbidden patterns either ignore_patterns = ignore_patterns + [f"{i}.index.*json" for i in ignore_patterns] - re_ignore_pattern = [re.compile(fnmatch.translate(p)) for p in ignore_patterns] re_allow_pattern = [re.compile(fnmatch.translate(p)) for p in allow_patterns] @@ -1044,6 +1047,18 @@ def _get_signature_keys(cls, obj): return expected_modules, optional_parameters + @classmethod + def _get_signature_types(cls): + signature_types = {} + for k, v in inspect.signature(cls.__init__).parameters.items(): + if inspect.isclass(v.annotation): + signature_types[k] = (v.annotation,) + elif get_origin(v.annotation) == Union: + signature_types[k] = get_args(v.annotation) + else: + logger.warning(f"cannot get type annotation for Parameter {k} of {cls}.") + return signature_types + @property def components(self) -> Dict[str, Any]: r""" @@ -1159,3 +1174,126 @@ def fn_recursive_set_mem_eff(module: nn.Cell): for module in modules: fn_recursive_set_mem_eff(module) + + @classmethod + def from_pipe(cls, pipeline, **kwargs): + r""" + Create a new pipeline from a given pipeline. This method is useful to create a new pipeline from the existing + pipeline components without reallocating additional memory. + + Arguments: + pipeline (`DiffusionPipeline`): + The pipeline from which to create a new pipeline. + + Returns: + `DiffusionPipeline`: + A new pipeline with the same weights and configurations as `pipeline`. + + Examples: + + ```py + >>> from mindone.diffusers import StableDiffusionPipeline, StableDiffusionSAGPipeline + + >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") + >>> new_pipe = StableDiffusionSAGPipeline.from_pipe(pipe) + ``` + """ + + original_config = dict(pipeline.config) + mindspore_dtype = kwargs.pop("mindspore_dtype", None) + + # derive the pipeline class to instantiate + custom_pipeline = kwargs.pop("custom_pipeline", None) + custom_revision = kwargs.pop("custom_revision", None) + + if custom_pipeline is not None: + pipeline_class = _get_custom_pipeline_class(custom_pipeline, revision=custom_revision) + else: + pipeline_class = cls + + expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class) + # true_optional_modules are optional components with default value in signature so it is ok not to pass them to `__init__` + # e.g. `image_encoder` for StableDiffusionPipeline + parameters = inspect.signature(cls.__init__).parameters + true_optional_modules = set( + {k for k, v in parameters.items() if v.default != inspect._empty and k in expected_modules} + ) + + # get the class of each component based on its type hint + # e.g. {"unet": UNet2DConditionModel, "text_encoder": CLIPTextMode} + component_types = pipeline_class._get_signature_types() + + pretrained_model_name_or_path = original_config.pop("_name_or_path", None) + # allow users pass modules in `kwargs` to override the original pipeline's components + passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} + + original_class_obj = {} + for name, component in pipeline.components.items(): + if name in expected_modules and name not in passed_class_obj: + # for model components, we will not switch over if the class does not matches the type hint in the new pipeline's signature + if ( + not isinstance(component, ModelMixin) + or type(component) in component_types[name] + or (component is None and name in cls._optional_components) + ): + original_class_obj[name] = component + else: + logger.warning( + f"component {name} is not switched over to new pipeline because type does not match the expected." + f" {name} is {type(component)} while the new pipeline expect {component_types[name]}." + f" please pass the component of the correct type to the new pipeline. `from_pipe(..., {name}={name})`" + ) + + # allow users pass optional kwargs to override the original pipelines config attribute + passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs} + original_pipe_kwargs = { + k: original_config[k] + for k in original_config.keys() + if k in optional_kwargs and k not in passed_pipe_kwargs + } + + # config attribute that were not expected by pipeline is stored as its private attribute + # (i.e. when the original pipeline was also instantiated with `from_pipe` from another pipeline that has this config) + # in this case, we will pass them as optional arguments if they can be accepted by the new pipeline + additional_pipe_kwargs = [ + k[1:] + for k in original_config.keys() + if k.startswith("_") and k[1:] in optional_kwargs and k[1:] not in passed_pipe_kwargs + ] + for k in additional_pipe_kwargs: + original_pipe_kwargs[k] = original_config.pop(f"_{k}") + + pipeline_kwargs = { + **passed_class_obj, + **original_class_obj, + **passed_pipe_kwargs, + **original_pipe_kwargs, + **kwargs, + } + + # store unused config as private attribute in the new pipeline + unused_original_config = { + f"{'' if k.startswith('_') else '_'}{k}": v for k, v in original_config.items() if k not in pipeline_kwargs + } + + missing_modules = ( + set(expected_modules) + - set(pipeline._optional_components) + - set(pipeline_kwargs.keys()) + - set(true_optional_modules) + ) + + if len(missing_modules) > 0: + raise ValueError( + f"Pipeline {pipeline_class} expected {expected_modules}, but only {set(list(passed_class_obj.keys()) + list(original_class_obj.keys()))} were passed" # noqa: E501 + ) + + new_pipeline = pipeline_class(**pipeline_kwargs) + if pretrained_model_name_or_path is not None: + new_pipeline.register_to_config(_name_or_path=pretrained_model_name_or_path) + new_pipeline.register_to_config(**unused_original_config) + + if mindspore_dtype is not None: + new_pipeline.to(dtype=mindspore_dtype) + + return new_pipeline diff --git a/mindone/diffusers/pipelines/pixart_alpha/__init__.py b/mindone/diffusers/pipelines/pixart_alpha/__init__.py index faa2497e32..e57abb7225 100644 --- a/mindone/diffusers/pipelines/pixart_alpha/__init__.py +++ b/mindone/diffusers/pipelines/pixart_alpha/__init__.py @@ -4,6 +4,7 @@ _import_structure = {} _import_structure["pipeline_pixart_alpha"] = ["PixArtAlphaPipeline"] +_import_structure["pipeline_pixart_sigma"] = ["PixArtSigmaPipeline"] if TYPE_CHECKING: from .pipeline_pixart_alpha import ( @@ -12,6 +13,7 @@ ASPECT_RATIO_1024_BIN, PixArtAlphaPipeline, ) + from .pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN, PixArtSigmaPipeline else: import sys diff --git a/mindone/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/mindone/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py index cd12631346..0cb63f496b 100644 --- a/mindone/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +++ b/mindone/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py @@ -24,8 +24,8 @@ import mindspore as ms from mindspore import ops -from ...image_processor import VaeImageProcessor -from ...models import AutoencoderKL, Transformer2DModel +from ...image_processor import PixArtImageProcessor +from ...models import AutoencoderKL, PixArtTransformer2DModel from ...schedulers import DPMSolverMultistepScheduler from ...utils import BACKENDS_MAPPING, deprecate, is_bs4_available, is_ftfy_available, logging from ...utils.mindspore_utils import randn_tensor @@ -167,6 +167,7 @@ def retrieve_timesteps( scheduler, num_inference_steps: Optional[int] = None, timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, **kwargs, ): """ @@ -177,17 +178,21 @@ def retrieve_timesteps( scheduler (`SchedulerMixin`): The scheduler to get timesteps from. num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, - `timesteps` must be `None`. + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default - timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` - must be `None`. + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[ms.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: @@ -198,6 +203,16 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, **kwargs) timesteps = scheduler.timesteps @@ -221,8 +236,8 @@ class PixArtAlphaPipeline(DiffusionPipeline): tokenizer (`T5Tokenizer`): Tokenizer of class [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). - transformer ([`Transformer2DModel`]): - A text conditioned `Transformer2DModel` to denoise the encoded image latents. + transformer ([`PixArtTransformer2DModel`]): + A text conditioned `PixArtTransformer2DModel` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. """ @@ -251,7 +266,7 @@ def __init__( tokenizer: T5Tokenizer, text_encoder: T5EncoderModel, vae: AutoencoderKL, - transformer: Transformer2DModel, + transformer: PixArtTransformer2DModel, scheduler: DPMSolverMultistepScheduler, ): super().__init__() @@ -261,16 +276,7 @@ def __init__( ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - - # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py - def mask_text_embeddings(self, emb, mask): - if emb.shape[0] == 1: - keep_index = mask.sum().item() - return emb[:, :, :keep_index, :], keep_index - else: - masked_feature = emb * mask[:, None, :, None] - return masked_feature, emb.shape[2] + self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) def encode_prompt( self, @@ -342,7 +348,7 @@ def encode_prompt( if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not ops.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" + "The following part of your input was truncated because T5 can only handle sequences up to" f" {max_length} tokens: {removed_text}" ) @@ -369,7 +375,7 @@ def encode_prompt( # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens = [negative_prompt] * batch_size + uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( @@ -630,7 +636,12 @@ def _clean_caption(self, caption): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -648,44 +659,13 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype latents = latents.to(dtype=dtype) return latents - @staticmethod - def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]: - """Returns binned height and width.""" - ar = float(height / width) - closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar)) - default_hw = ratios[closest_ratio] - return int(default_hw[0]), int(default_hw[1]) - - @staticmethod - def resize_and_crop_tensor(samples: ms.Tensor, new_width: int, new_height: int) -> ms.Tensor: - orig_height, orig_width = samples.shape[2], samples.shape[3] - - # Check if resizing is needed - if orig_height != new_height or orig_width != new_width: - ratio = max(new_height / orig_height, new_width / orig_width) - resized_width = int(orig_width * ratio) - resized_height = int(orig_height * ratio) - - # Resize - samples = ops.interpolate( - samples, size=(resized_height, resized_width), mode="bilinear", align_corners=False - ) - - # Center Crop - start_x = (resized_width - new_width) // 2 - end_x = start_x + new_width - start_y = (resized_height - new_height) // 2 - end_y = start_y + new_height - samples = samples[:, :, start_y:end_y, start_x:end_x] - - return samples - def __call__( self, prompt: Union[str, List[str]] = None, negative_prompt: str = "", num_inference_steps: int = 20, timesteps: List[int] = None, + sigmas: List[float] = None, guidance_scale: float = 4.5, num_images_per_prompt: Optional[int] = 1, height: Optional[int] = None, @@ -721,8 +701,13 @@ def __call__( The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. timesteps (`List[int]`, *optional*): - Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` - timesteps are used. Must be in descending order. + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. guidance_scale (`float`, *optional*, defaults to 4.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen @@ -799,7 +784,7 @@ def __call__( else: raise ValueError("Invalid sample size") orig_height, orig_width = height, width - height, width = self.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) self.check_inputs( prompt, @@ -849,7 +834,7 @@ def __call__( prompt_attention_mask = ops.cat([negative_prompt_attention_mask, prompt_attention_mask], axis=0) # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps) + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps, sigmas) # 5. Prepare latents. latent_channels = self.transformer.config.in_channels @@ -926,7 +911,13 @@ def __call__( # compute previous image: x_t -> x_t-1 latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if num_inference_steps == 1: + # For DMD one step sampling: https://arxiv.org/abs/2311.18828 + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs, return_dict=True + ).pred_original_sample + else: + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] latents = latents.to(latents_dtype) # call the callback, if provided @@ -939,7 +930,7 @@ def __call__( if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] if use_resolution_binning: - image = self.resize_and_crop_tensor(image, orig_width, orig_height) + image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) else: image = latents diff --git a/mindone/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py b/mindone/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py new file mode 100644 index 0000000000..7569e3563e --- /dev/null +++ b/mindone/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py @@ -0,0 +1,846 @@ +# Copyright 2024 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import re +import urllib.parse as ul +from typing import Callable, List, Optional, Tuple, Union + +import numpy as np +from transformers import T5Tokenizer + +import mindspore as ms +from mindspore import ops + +from ....transformers import T5EncoderModel +from ...image_processor import PixArtImageProcessor +from ...models import AutoencoderKL, PixArtTransformer2DModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import BACKENDS_MAPPING, deprecate, is_bs4_available, is_ftfy_available, logging +from ...utils.mindspore_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from .pipeline_pixart_alpha import ASPECT_RATIO_256_BIN, ASPECT_RATIO_512_BIN, ASPECT_RATIO_1024_BIN + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +ASPECT_RATIO_2048_BIN = { + "0.25": [1024.0, 4096.0], + "0.26": [1024.0, 3968.0], + "0.27": [1024.0, 3840.0], + "0.28": [1024.0, 3712.0], + "0.32": [1152.0, 3584.0], + "0.33": [1152.0, 3456.0], + "0.35": [1152.0, 3328.0], + "0.4": [1280.0, 3200.0], + "0.42": [1280.0, 3072.0], + "0.48": [1408.0, 2944.0], + "0.5": [1408.0, 2816.0], + "0.52": [1408.0, 2688.0], + "0.57": [1536.0, 2688.0], + "0.6": [1536.0, 2560.0], + "0.68": [1664.0, 2432.0], + "0.72": [1664.0, 2304.0], + "0.78": [1792.0, 2304.0], + "0.82": [1792.0, 2176.0], + "0.88": [1920.0, 2176.0], + "0.94": [1920.0, 2048.0], + "1.0": [2048.0, 2048.0], + "1.07": [2048.0, 1920.0], + "1.13": [2176.0, 1920.0], + "1.21": [2176.0, 1792.0], + "1.29": [2304.0, 1792.0], + "1.38": [2304.0, 1664.0], + "1.46": [2432.0, 1664.0], + "1.67": [2560.0, 1536.0], + "1.75": [2688.0, 1536.0], + "2.0": [2816.0, 1408.0], + "2.09": [2944.0, 1408.0], + "2.4": [3072.0, 1280.0], + "2.5": [3200.0, 1280.0], + "2.89": [3328.0, 1152.0], + "3.0": [3456.0, 1152.0], + "3.11": [3584.0, 1152.0], + "3.62": [3712.0, 1024.0], + "3.75": [3840.0, 1024.0], + "3.88": [3968.0, 1024.0], + "4.0": [4096.0, 1024.0], +} + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import mindspore + >>> from mindone.diffusers import PixArtSigmaPipeline + + >>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-Sigma-XL-2-512-MS" too. + >>> pipe = PixArtSigmaPipeline.from_pretrained( + ... "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", mindspore_dtype=mindspore.float16 + ... ) + + >>> prompt = "A small cactus with a happy face in the Sahara desert." + >>> image = pipe(prompt)[0][0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[ms.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and + the second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class PixArtSigmaPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using PixArt-Sigma. + """ + + bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + "\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder"] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKL, + transformer: PixArtTransformer2DModel, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) + + # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.encode_prompt with 120->300 + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + prompt_embeds: Optional[ms.Tensor] = None, + negative_prompt_embeds: Optional[ms.Tensor] = None, + prompt_attention_mask: Optional[ms.Tensor] = None, + negative_prompt_attention_mask: Optional[ms.Tensor] = None, + clean_caption: bool = False, + max_sequence_length: int = 300, + **kwargs, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + PixArt-Alpha, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the "" + string. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt. + """ + + if "mask_feature" in kwargs: + deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation \ + and that doesn't affect the end results. It will be removed in a future version." + deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # See Section 3.1. of the paper. + max_length = max_sequence_length + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="np", + ) + text_input_ids = ms.Tensor.from_numpy(text_inputs.input_ids) + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not ops.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + prompt_attention_mask = ms.Tensor.from_numpy(text_inputs.attention_mask) + + prompt_embeds = self.text_encoder(text_input_ids, attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.tile((1, num_images_per_prompt, 1)) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.tile((num_images_per_prompt, 1)) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens = [negative_prompt] * batch_size + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="np", + ) + negative_prompt_attention_mask = ms.Tensor.from_numpy(uncond_input.attention_mask) + + negative_prompt_embeds = self.text_encoder( + ms.Tensor.from_numpy(uncond_input.input_ids), attention_mask=negative_prompt_attention_mask + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype) + + negative_prompt_embeds = negative_prompt_embeds.tile((1, num_images_per_prompt, 1)) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.tile((num_images_per_prompt, 1)) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_steps, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, dtype=dtype) + else: + latents = latents.to(dtype) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + # wtf? The above line changes the dtype of latents from fp16 to fp32, so we need a casting. + latents = latents.to(dtype=dtype) + return latents + + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: str = "", + num_inference_steps: int = 20, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: float = 4.5, + num_images_per_prompt: Optional[int] = 1, + height: Optional[int] = None, + width: Optional[int] = None, + eta: float = 0.0, + generator: Optional[Union[np.random.Generator, List[np.random.Generator]]] = None, + latents: Optional[ms.Tensor] = None, + prompt_embeds: Optional[ms.Tensor] = None, + prompt_attention_mask: Optional[ms.Tensor] = None, + negative_prompt_embeds: Optional[ms.Tensor] = None, + negative_prompt_attention_mask: Optional[ms.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = False, + callback: Optional[Callable[[int, int, ms.Tensor], None]] = None, + callback_steps: int = 1, + clean_caption: bool = True, + use_resolution_binning: bool = True, + max_sequence_length: int = 300, + **kwargs, + ) -> Union[ImagePipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 4.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`np.random.Generator` or `List[np.random.Generator]`, *optional*): + One or a list of [numpy generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`ms.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`ms.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`ms.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: ms.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + use_resolution_binning (`bool` defaults to `True`): + If set to `True`, the requested height and width are first mapped to the closest resolutions using + `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to + the requested resolution. Useful for generating non-square images. + max_sequence_length (`int` defaults to 300): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + # 1. Check inputs. Raise error if not correct + height = height or self.transformer.config.sample_size * self.vae_scale_factor + width = width or self.transformer.config.sample_size * self.vae_scale_factor + if use_resolution_binning: + if self.transformer.config.sample_size == 256: + aspect_ratio_bin = ASPECT_RATIO_2048_BIN + elif self.transformer.config.sample_size == 128: + aspect_ratio_bin = ASPECT_RATIO_1024_BIN + elif self.transformer.config.sample_size == 64: + aspect_ratio_bin = ASPECT_RATIO_512_BIN + elif self.transformer.config.sample_size == 32: + aspect_ratio_bin = ASPECT_RATIO_256_BIN + else: + raise ValueError("Invalid sample size") + orig_height, orig_width = height, width + height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_steps, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + ) + if do_classifier_free_guidance: + prompt_embeds = ops.cat([negative_prompt_embeds, prompt_embeds], axis=0) + prompt_attention_mask = ops.cat([negative_prompt_attention_mask, prompt_attention_mask], axis=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps, sigmas) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + prompt_embeds.dtype, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Prepare micro-conditions. + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latent_model_input = ops.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + current_timestep = t + if not ops.is_tensor(current_timestep): + if isinstance(current_timestep, float): + dtype = ms.float64 + else: + dtype = ms.int64 + current_timestep = ms.Tensor([current_timestep], dtype=dtype) + elif len(current_timestep.shape) == 0: + current_timestep = current_timestep[None] + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + current_timestep = current_timestep.broadcast_to((latent_model_input.shape[0],)) + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=current_timestep, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # learned sigma + if self.transformer.config.out_channels // 2 == latent_channels: + noise_pred = noise_pred.chunk(2, axis=1)[0] + else: + noise_pred = noise_pred + + # compute previous image: x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + if use_resolution_binning: + image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) + else: + image = latents + + if not output_type == "latent": + image = self.image_processor.postprocess(image, output_type=output_type) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/mindone/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py b/mindone/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py index 25adb0c7c8..1e8df3ede4 100644 --- a/mindone/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +++ b/mindone/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py @@ -232,14 +232,14 @@ def __call__( num_embeddings = self.prior.config.num_embeddings embedding_dim = self.prior.config.embedding_dim - - latents = self.prepare_latents( - (batch_size, num_embeddings * embedding_dim), - image_embeds.dtype, - generator, - latents, - self.scheduler, - ) + if latents is None: + latents = self.prepare_latents( + (batch_size, num_embeddings * embedding_dim), + image_embeds.dtype, + generator, + latents, + self.scheduler, + ) # YiYi notes: for testing only to match ldm, we can directly create a latents with desired shape: batch_size, num_embeddings, embedding_dim latents = latents.reshape(latents.shape[0], num_embeddings, embedding_dim) diff --git a/mindone/diffusers/pipelines/shap_e/renderer.py b/mindone/diffusers/pipelines/shap_e/renderer.py index 9af6839bd3..93c14afc4b 100644 --- a/mindone/diffusers/pipelines/shap_e/renderer.py +++ b/mindone/diffusers/pipelines/shap_e/renderer.py @@ -516,8 +516,8 @@ class MeshDecoder(nn.Cell): def __init__(self): super().__init__() - self.cases = ops.zeros((256, 5, 3), dtype=ms.int64) - self.masks = ops.zeros((256, 5), dtype=ms.bool_) + self.cases = ms.Parameter(ops.zeros((256, 5, 3), dtype=ms.int64), name="cases") + self.masks = ms.Parameter(ops.zeros((256, 5), dtype=ms.bool_), name="masks") def construct(self, field: ms.Tensor, min_point: ms.Tensor, size: ms.Tensor): """ @@ -532,8 +532,12 @@ def construct(self, field: ms.Tensor, min_point: ms.Tensor, size: ms.Tensor): """ assert len(field.shape) == 3, "input must be a 3D scalar field" - cases = self.cases - masks = self.masks + # In PyTorch, cases and masks are registered buffers which could be loaded by ckpt + # and their data-type would NOT be changed when pipeline is loaded with `torch_dtype=tgt_dtype`. + # In MindSpore we define them as Parameter which could be loaded while their dtype would + # be CHANGED. Therefore we cast them to original data-type manually. + cases = self.cases.long() + masks = self.masks.bool() grid_size = field.shape grid_size_tensor = ms.Tensor(grid_size).to(size.dtype) @@ -541,15 +545,15 @@ def construct(self, field: ms.Tensor, min_point: ms.Tensor, size: ms.Tensor): # Create bitmasks between 0 and 255 (inclusive) indicating the state # of the eight corners of each cube. bitmasks = (field > 0).to(ms.uint8) - bitmasks = bitmasks[:-1, :, :] | (bitmasks[1:, :, :] << 1) - bitmasks = bitmasks[:, :-1, :] | (bitmasks[:, 1:, :] << 2) - bitmasks = bitmasks[:, :, :-1] | (bitmasks[:, :, 1:] << 4) + bitmasks = bitmasks[:-1, :, :] | (bitmasks[1:, :, :] * 2**1) + bitmasks = bitmasks[:, :-1, :] | (bitmasks[:, 1:, :] * 2**2) + bitmasks = bitmasks[:, :, :-1] | (bitmasks[:, :, 1:] * 2**4) # Compute corner coordinates across the entire grid. corner_coords = ops.zeros(grid_size + (3,), dtype=field.dtype) - corner_coords[range(grid_size[0]), :, :, 0] = ops.arange(grid_size[0], dtype=field.dtype)[:, None, None] - corner_coords[:, range(grid_size[1]), :, 1] = ops.arange(grid_size[1], dtype=field.dtype)[:, None] - corner_coords[:, :, range(grid_size[2]), 2] = ops.arange(grid_size[2], dtype=field.dtype) + corner_coords[:, :, :, 0] += ops.arange(grid_size[0], dtype=field.dtype)[:, None, None] + corner_coords[:, :, :, 1] += ops.arange(grid_size[1], dtype=field.dtype)[None, :, None] + corner_coords[:, :, :, 2] += ops.arange(grid_size[2], dtype=field.dtype)[None, None, :] # Compute all vertices across all edges in the grid, even though we will # throw some out later. We have (X-1)*Y*Z + X*(Y-1)*Z + X*Y*(Z-1) vertices. @@ -566,9 +570,9 @@ def construct(self, field: ms.Tensor, min_point: ms.Tensor, size: ms.Tensor): # Create a flat array of [X, Y, Z] indices for each cube. cube_indices = ops.zeros((grid_size[0] - 1, grid_size[1] - 1, grid_size[2] - 1, 3), dtype=ms.int64) - cube_indices[range(grid_size[0] - 1), :, :, 0] = ops.arange(grid_size[0] - 1)[:, None, None] - cube_indices[:, range(grid_size[1] - 1), :, 1] = ops.arange(grid_size[1] - 1)[:, None] - cube_indices[:, :, range(grid_size[2] - 1), 2] = ops.arange(grid_size[2] - 1) + cube_indices[:, :, :, 0] += ops.arange(grid_size[0] - 1)[:, None, None] + cube_indices[:, :, :, 1] += ops.arange(grid_size[1] - 1)[None, :, None] + cube_indices[:, :, :, 2] += ops.arange(grid_size[2] - 1)[None, None, :] flat_cube_indices = cube_indices.reshape(-1, 3) # Create a flat array mapping each cube to 12 global edge indices. @@ -577,7 +581,7 @@ def construct(self, field: ms.Tensor, min_point: ms.Tensor, size: ms.Tensor): # Apply the LUT to figure out the triangles. flat_bitmasks = bitmasks.reshape(-1).long() # must cast to long for indexing to believe this not a mask local_tris = cases[flat_bitmasks] - local_masks = masks[flat_bitmasks] + local_masks = masks.long()[flat_bitmasks].bool() # bool tensor couldn't sliced like this # Compute the global edge indices for the triangles. global_tris = ops.gather_elements(edge_indices, 1, local_tris.reshape(local_tris.shape[0], -1)).reshape( local_tris.shape @@ -1002,10 +1006,12 @@ def decode_to_mesh( # create grid 128 x 128 x 128 # - force a negative border around the SDFs to close off all the models. full_grid = ops.zeros( - 1, - grid_size + 2, - grid_size + 2, - grid_size + 2, + size=( + 1, + grid_size + 2, + grid_size + 2, + grid_size + 2, + ), dtype=fields.dtype, ) full_grid = full_grid.fill(-1.0) diff --git a/mindone/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/mindone/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 6fec8777ac..859152f780 100644 --- a/mindone/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/mindone/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -24,6 +24,7 @@ from mindone.transformers import CLIPTextModel, CLIPVisionModelWithProjection +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin @@ -58,8 +59,8 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 """ - std_text = noise_pred_text.std(axis=list(range(1, noise_pred_text.ndim)), keepdims=True) - std_cfg = noise_cfg.std(axis=list(range(1, noise_cfg.ndim)), keepdims=True) + std_text = noise_pred_text.std(axis=tuple(range(1, noise_pred_text.ndim)), keepdims=True) + std_cfg = noise_cfg.std(axis=tuple(range(1, noise_cfg.ndim)), keepdims=True) # rescale the results from guidance (fixes overexposure) noise_pred_rescaled = noise_cfg * (std_text / std_cfg) # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images @@ -71,6 +72,7 @@ def retrieve_timesteps( scheduler, num_inference_steps: Optional[int] = None, timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, **kwargs, ): """ @@ -81,17 +83,21 @@ def retrieve_timesteps( scheduler (`SchedulerMixin`): The scheduler to get timesteps from. num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, - `timesteps` must be `None`. + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default - timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` - must be `None`. + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[ms.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: @@ -102,6 +108,16 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, **kwargs) timesteps = scheduler.timesteps @@ -440,9 +456,10 @@ def encode_prompt( negative_prompt_embeds = negative_prompt_embeds.tile((1, num_images_per_prompt, 1)) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - if isinstance(self, LoraLoaderMixin): - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder is not None: + if isinstance(self, LoraLoaderMixin): + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) return prompt_embeds, negative_prompt_embeds @@ -455,11 +472,9 @@ def encode_image(self, image, num_images_per_prompt, output_hidden_states=None): image = image.to(dtype=dtype) if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True)[2][-2] image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - ops.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] + uncond_image_enc_hidden_states = self.image_encoder(ops.zeros_like(image), output_hidden_states=True)[2][-2] uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( num_images_per_prompt, dim=0 ) @@ -638,7 +653,12 @@ def check_inputs( ) def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -657,7 +677,9 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype return latents # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding - def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=ms.float32): + def get_guidance_scale_embedding( + self, w: ms.Tensor, embedding_dim: int = 512, dtype: ms.Type = ms.float32 + ) -> ms.Tensor: """ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 @@ -666,7 +688,7 @@ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=ms.float32): Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. embedding_dim (`int`, *optional*, defaults to 512): Dimension of the embeddings to generate. - dtype: + dtype (`ms.dtype`, *optional*, defaults to `ms.float32`): Data type of the generated embeddings. Returns: @@ -723,6 +745,7 @@ def __call__( width: Optional[int] = None, num_inference_steps: int = 50, timesteps: List[int] = None, + sigmas: List[float] = None, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, @@ -738,7 +761,9 @@ def __call__( cross_attention_kwargs: Optional[Dict[str, Any]] = None, guidance_rescale: float = 0.0, clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): @@ -759,6 +784,10 @@ def __call__( Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. @@ -785,10 +814,10 @@ def __call__( not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. ip_adapter_image_embeds (`List[ms.Tensor]`, *optional*): - Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. - Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding - if `do_classifier_free_guidance` is set to `True`. - If not provided, embeddings are computed from the `ip_adapter_image` input argument. + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `False`): @@ -804,11 +833,11 @@ def __call__( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the @@ -840,6 +869,9 @@ def __call__( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor @@ -902,7 +934,7 @@ def __call__( ) # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps) + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps, sigmas) # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels diff --git a/mindone/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/mindone/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py index a5d75b036f..76f50e28d8 100644 --- a/mindone/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +++ b/mindone/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -347,9 +347,10 @@ def encode_prompt( negative_prompt_embeds = negative_prompt_embeds.tile((1, num_images_per_prompt, 1)) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - if isinstance(self, LoraLoaderMixin): - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder is not None: + if isinstance(self, LoraLoaderMixin): + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) return prompt_embeds, negative_prompt_embeds diff --git a/mindone/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/mindone/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py index 6ab7be36a4..f2c7ae2393 100644 --- a/mindone/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +++ b/mindone/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py @@ -225,7 +225,12 @@ def check_inputs(self, image, height, width, callback_steps): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" diff --git a/mindone/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/mindone/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index fb6e2806c8..56a3b31c13 100644 --- a/mindone/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/mindone/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -25,6 +25,7 @@ from mindone.transformers import CLIPTextModel, CLIPVisionModelWithProjection +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin @@ -106,6 +107,7 @@ def retrieve_timesteps( scheduler, num_inference_steps: Optional[int] = None, timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, **kwargs, ): """ @@ -116,17 +118,21 @@ def retrieve_timesteps( scheduler (`SchedulerMixin`): The scheduler to get timesteps from. num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, - `timesteps` must be `None`. + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default - timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` - must be `None`. + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[ms.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: @@ -137,6 +143,16 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, **kwargs) timesteps = scheduler.timesteps @@ -475,9 +491,10 @@ def encode_prompt( negative_prompt_embeds = negative_prompt_embeds.tile((1, num_images_per_prompt, 1)) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - if isinstance(self, LoraLoaderMixin): - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder is not None: + if isinstance(self, LoraLoaderMixin): + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) return prompt_embeds, negative_prompt_embeds @@ -741,7 +758,9 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt return latents # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding - def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=ms.float32): + def get_guidance_scale_embedding( + self, w: ms.Tensor, embedding_dim: int = 512, dtype: ms.Type = ms.float32 + ) -> ms.Tensor: """ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 @@ -750,7 +769,7 @@ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=ms.float32): Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. embedding_dim (`int`, *optional*, defaults to 512): Dimension of the embeddings to generate. - dtype: + dtype (`ms.dtype`, *optional*, defaults to `ms.float32`): Data type of the generated embeddings. Returns: @@ -803,6 +822,7 @@ def __call__( strength: float = 0.8, num_inference_steps: int = 50, timesteps: List[int] = None, + sigmas: List[float] = None, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, @@ -816,7 +836,9 @@ def __call__( return_dict: bool = False, cross_attention_kwargs: Optional[Dict[str, Any]] = None, clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): @@ -845,6 +867,10 @@ def __call__( Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. @@ -882,11 +908,11 @@ def __call__( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the @@ -918,6 +944,9 @@ def __call__( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, @@ -978,7 +1007,7 @@ def __call__( image = self.image_processor.preprocess(image) # 5. set timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps) + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps, sigmas) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) diff --git a/mindone/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/mindone/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 6227f9e9c0..c6cbaa78f7 100644 --- a/mindone/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/mindone/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -25,6 +25,7 @@ from mindone.transformers import CLIPTextModel, CLIPVisionModelWithProjection +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin @@ -183,6 +184,7 @@ def retrieve_timesteps( scheduler, num_inference_steps: Optional[int] = None, timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, **kwargs, ): """ @@ -193,17 +195,21 @@ def retrieve_timesteps( scheduler (`SchedulerMixin`): The scheduler to get timesteps from. num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, - `timesteps` must be `None`. + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default - timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` - must be `None`. + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[ms.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: @@ -214,6 +220,16 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, **kwargs) timesteps = scheduler.timesteps @@ -560,9 +576,10 @@ def encode_prompt( negative_prompt_embeds = negative_prompt_embeds.tile((1, num_images_per_prompt, 1)) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - if isinstance(self, LoraLoaderMixin): - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder is not None: + if isinstance(self, LoraLoaderMixin): + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) return prompt_embeds, negative_prompt_embeds @@ -576,11 +593,9 @@ def encode_image(self, image, num_images_per_prompt, output_hidden_states=None): image = image.to(dtype=dtype) if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True)[2][-2] image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - ops.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] + uncond_image_enc_hidden_states = self.image_encoder(ops.zeros_like(image), output_hidden_states=True)[2][-2] uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( num_images_per_prompt, dim=0 ) @@ -780,7 +795,12 @@ def prepare_latents( return_noise=False, return_image_latents=False, ): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -892,7 +912,9 @@ def get_timesteps(self, num_inference_steps, strength): return timesteps, num_inference_steps - t_start # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding - def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=ms.float32): + def get_guidance_scale_embedding( + self, w: ms.Tensor, embedding_dim: int = 512, dtype: ms.Type = ms.float32 + ) -> ms.Tensor: """ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 @@ -901,7 +923,7 @@ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=ms.float32): Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. embedding_dim (`int`, *optional*, defaults to 512): Dimension of the embeddings to generate. - dtype: + dtype (`ms.dtype`, *optional*, defaults to `ms.float32`): Data type of the generated embeddings. Returns: @@ -959,6 +981,7 @@ def __call__( strength: float = 1.0, num_inference_steps: int = 50, timesteps: List[int] = None, + sigmas: List[float] = None, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, @@ -973,7 +996,9 @@ def __call__( return_dict: bool = False, cross_attention_kwargs: Optional[Dict[str, Any]] = None, clip_skip: int = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): @@ -1003,11 +1028,11 @@ def __call__( The width in pixels of the generated image. padding_mask_crop (`int`, *optional*, defaults to `None`): The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to - image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region with - the same aspect ration of the image and contains all masked area, and then expand that area based on - `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before + image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region + with the same aspect ration of the image and contains all masked area, and then expand that area based + on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small while - the image is large and contain information inreleant for inpainging, such as background. + the image is large and contain information irrelevant for inpainting, such as background. strength (`float`, *optional*, defaults to 1.0): Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a starting point and more noise is added the higher the `strength`. The number of denoising steps depends @@ -1021,6 +1046,10 @@ def __call__( Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. @@ -1062,11 +1091,11 @@ def __call__( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the @@ -1125,6 +1154,9 @@ def __call__( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor @@ -1190,7 +1222,7 @@ def __call__( ) # 4. set timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps) + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps, sigmas) timesteps, num_inference_steps = self.get_timesteps( num_inference_steps=num_inference_steps, strength=strength, diff --git a/mindone/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/mindone/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py index 67b51d3228..198dee2551 100644 --- a/mindone/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +++ b/mindone/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py @@ -606,11 +606,9 @@ def encode_image(self, image, num_images_per_prompt, output_hidden_states=None): image = image.to(dtype=dtype) if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True)[-1][-2] + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True)[2][-2] image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder(ops.zeros_like(image), output_hidden_states=True)[-1][ - -2 - ] + uncond_image_enc_hidden_states = self.image_encoder(ops.zeros_like(image), output_hidden_states=True)[2][-2] uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( num_images_per_prompt, dim=0 ) diff --git a/mindone/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/mindone/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index edc8134508..833050d469 100644 --- a/mindone/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/mindone/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -358,9 +358,10 @@ def encode_prompt( negative_prompt_embeds = negative_prompt_embeds.tile((1, num_images_per_prompt, 1)) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - if isinstance(self, LoraLoaderMixin): - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder is not None: + if isinstance(self, LoraLoaderMixin): + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) return prompt_embeds, negative_prompt_embeds @@ -449,7 +450,7 @@ def check_inputs( ) # verify batch size of prompt and image are same if image is a list or tensor or numpy array - if isinstance(image, list) or isinstance(image, ms.Tensor) or isinstance(image, np.ndarray): + if isinstance(image, (list, np.ndarray, ms.Tensor)): if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/mindone/diffusers/pipelines/stable_diffusion/safety_checker.py b/mindone/diffusers/pipelines/stable_diffusion/safety_checker.py index 05960bcc56..5ec8b27edb 100644 --- a/mindone/diffusers/pipelines/stable_diffusion/safety_checker.py +++ b/mindone/diffusers/pipelines/stable_diffusion/safety_checker.py @@ -30,6 +30,7 @@ def cosine_distance(image_embeds, text_embeds): class StableDiffusionSafetyChecker(MSPreTrainedModel): config_class = CLIPConfig + main_input_name = "clip_input" _keys_to_ignore_on_load_unexpected = ["vision_model.vision_model.embeddings.position_ids"] _no_split_modules = ["CLIPEncoderLayer"] diff --git a/mindone/diffusers/pipelines/stable_diffusion_3/__init__.py b/mindone/diffusers/pipelines/stable_diffusion_3/__init__.py index 04d6f473cc..7e9b9843bc 100644 --- a/mindone/diffusers/pipelines/stable_diffusion_3/__init__.py +++ b/mindone/diffusers/pipelines/stable_diffusion_3/__init__.py @@ -4,9 +4,11 @@ _import_structure = {"pipeline_output": ["StableDiffusion3PipelineOutput"]} _import_structure["pipeline_stable_diffusion_3"] = ["StableDiffusion3Pipeline"] +_import_structure["pipeline_stable_diffusion_3_img2img"] = ["StableDiffusion3Img2ImgPipeline"] if TYPE_CHECKING: from .pipeline_stable_diffusion_3 import StableDiffusion3Pipeline + from .pipeline_stable_diffusion_3_img2img import StableDiffusion3Img2ImgPipeline else: import sys diff --git a/mindone/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py b/mindone/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py new file mode 100644 index 0000000000..50adc998cb --- /dev/null +++ b/mindone/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py @@ -0,0 +1,871 @@ +# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +from transformers import CLIPTokenizer, T5TokenizerFast + +import mindspore as ms +from mindspore import ops + +from ....transformers import CLIPTextModelWithProjection, T5EncoderModel +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import SD3Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ...utils.mindspore_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import StableDiffusion3PipelineOutput + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import mindspore + + >>> from mindone.diffusers import StableDiffusion3Img2ImgPipeline + >>> from mindone.diffusers.utils import load_image + + >>> model_id_or_path = "stabilityai/stable-diffusion-3-medium-diffusers" + >>> pipe = StableDiffusion3Img2ImgPipeline.from_pretrained(model_id_or_path, mindspore_dtype=mindspore.float16) + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + >>> init_image = load_image(url).resize((512, 512)) + + >>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k" + + >>> images = pipe(prompt=prompt, image=init_image, strength=0.95, guidance_scale=7.5)[0][0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents(vae, encoder_output: ms.Tensor, sample_mode: str = "sample"): + if sample_mode == "sample": + return vae.diag_gauss_dist.sample(encoder_output) + elif sample_mode == "argmax": + return vae.diag_gauss_dist.mode(encoder_output) + # This branch is not needed because the encoder_output type is ms.Tensor as per AutoencoderKLOutput change + # elif hasattr(encoder_output, "latents"): + # return encoder_output.latents + else: + return encoder_output + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[ms.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and + the second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusion3Img2ImgPipeline(DiffusionPipeline): + r""" + Args: + transformer ([`SD3Transformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant, + with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size` + as its dimension. + text_encoder_2 ([`CLIPTextModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + text_encoder_3 ([`T5EncoderModel`]): + Frozen text-encoder. Stable Diffusion 3 uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_3 (`T5TokenizerFast`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"] + + def __init__( + self, + transformer: SD3Transformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer_2: CLIPTokenizer, + text_encoder_3: T5EncoderModel, + tokenizer_3: T5TokenizerFast, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + text_encoder_3=text_encoder_3, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + tokenizer_3=tokenizer_3, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae.config.latent_channels + ) + self.tokenizer_max_length = self.tokenizer.model_max_length + self.default_sample_size = self.transformer.config.sample_size + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + dtype=None, + ): + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if self.text_encoder_3 is None: + return ops.zeros( + (batch_size, self.tokenizer_max_length, self.transformer.config.joint_attention_dim), + dtype=dtype, + ) + + text_inputs = self.tokenizer_3( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + add_special_tokens=True, + return_tensors="np", + ) + text_input_ids = ms.Tensor.from_numpy(text_inputs.input_ids) + untruncated_ids = ms.Tensor.from_numpy( + self.tokenizer_3(prompt, padding="longest", return_tensors="np").input_ids + ) + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not ops.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_3(text_input_ids)[0] + + dtype = self.text_encoder_3.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.tile((1, num_images_per_prompt, 1)) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + clip_skip: Optional[int] = None, + clip_model_index: int = 0, + ): + clip_tokenizers = [self.tokenizer, self.tokenizer_2] + clip_text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = clip_tokenizers[clip_model_index] + text_encoder = clip_text_encoders[clip_model_index] + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_tensors="np", + ) + + text_input_ids = ms.Tensor.from_numpy(text_inputs.input_ids) + untruncated_ids = ms.Tensor.from_numpy(tokenizer(prompt, padding="longest", return_tensors="np").input_ids) + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not ops.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = text_encoder(text_input_ids, output_hidden_states=True) + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds[2][-2] + else: + prompt_embeds = prompt_embeds[2][-(clip_skip + 2)] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype) + + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.tile((1, num_images_per_prompt, 1)) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.tile((1, num_images_per_prompt, 1)) + pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds, pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + prompt_3: Union[str, List[str]], + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[ms.Tensor] = None, + negative_prompt_embeds: Optional[ms.Tensor] = None, + pooled_prompt_embeds: Optional[ms.Tensor] = None, + negative_pooled_prompt_embeds: Optional[ms.Tensor] = None, + clip_skip: Optional[int] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + used in all text-encoders + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + prompt_3 = prompt_3 or prompt + prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 + + prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=0, + ) + prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=1, + ) + clip_prompt_embeds = ops.cat([prompt_embed, prompt_2_embed], axis=-1) + + t5_prompt_embed = self._get_t5_prompt_embeds( + prompt=prompt_3, + num_images_per_prompt=num_images_per_prompt, + ) + + clip_prompt_embeds = ops.pad( + clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) + ) + + prompt_embeds = ops.cat([clip_prompt_embeds, t5_prompt_embed], axis=-2) + pooled_prompt_embeds = ops.cat([pooled_prompt_embed, pooled_prompt_2_embed], axis=-1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_3 = negative_prompt_3 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + negative_prompt_3 = ( + batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 + ) + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds( + negative_prompt, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=0, + ) + negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds( + negative_prompt_2, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=1, + ) + negative_clip_prompt_embeds = ops.cat([negative_prompt_embed, negative_prompt_2_embed], axis=-1) + + t5_negative_prompt_embed = self._get_t5_prompt_embeds( + prompt=negative_prompt_3, + num_images_per_prompt=num_images_per_prompt, + ) + + negative_clip_prompt_embeds = ops.pad( + negative_clip_prompt_embeds, + (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]), + ) + + negative_prompt_embeds = ops.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], axis=-2) + negative_pooled_prompt_embeds = ops.cat( + [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], axis=-1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + def check_inputs( + self, + prompt, + prompt_2, + prompt_3, + strength, + negative_prompt=None, + negative_prompt_2=None, + negative_prompt_3=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" # noqa: E501 + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_3 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)): + raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_3 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." # noqa: E501 + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." # noqa: E501 + ) + + def get_timesteps(self, num_inference_steps, strength): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, generator=None): + if not isinstance(image, (ms.Tensor, PIL.Image.Image, list)): + raise ValueError(f"`image` has to be of type `ms.Tensor`, `PIL.Image.Image` or list but is {type(image)}") + + image = image.to(dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + if image.shape[1] == self.vae.config.latent_channels: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + init_latents = [ + retrieve_latents(self.vae, self.vae.encode(image[i : i + 1])[0]) for i in range(batch_size) + ] + init_latents = ops.cat(init_latents, axis=0) + else: + init_latents = retrieve_latents(self.vae, self.vae.encode(image)[0]) + + init_latents = (init_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = ops.cat([init_latents] * additional_image_per_prompt, axis=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = ops.cat([init_latents], axis=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, dtype=dtype) + + # get latents + init_latents = self.scheduler.scale_noise(init_latents, timestep, noise) + latents = init_latents.to(dtype=dtype) + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + prompt_3: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + strength: float = 0.6, + num_inference_steps: int = 50, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[np.random.Generator, List[np.random.Generator]]] = None, + latents: Optional[ms.Tensor] = None, + prompt_embeds: Optional[ms.Tensor] = None, + negative_prompt_embeds: Optional[ms.Tensor] = None, + pooled_prompt_embeds: Optional[ms.Tensor] = None, + negative_pooled_prompt_embeds: Optional[ms.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = False, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + will be used instead + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used instead + negative_prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used instead + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`np.random.Generator` or `List[np.random.Generator]`, *optional*): + One or a list of [numpy generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`ms.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + strength, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + clip_skip=self.clip_skip, + num_images_per_prompt=num_images_per_prompt, + ) + + if self.do_classifier_free_guidance: + prompt_embeds = ops.cat([negative_prompt_embeds, prompt_embeds], axis=0) + pooled_prompt_embeds = ops.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], axis=0) + + # 3. Preprocess image + image = self.image_processor.preprocess(image) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength) + latent_timestep = timesteps[:1].tile((batch_size * num_inference_steps,)) + + # 5. Prepare latent variables + if latents is None: + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + generator, + ) + + # 6. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = ops.cat([latents] * 2) if self.do_classifier_free_guidance else latents + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.broadcast_to((latent_model_input.shape[0],)) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + pooled_projections=pooled_prompt_embeds, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if output_type == "latent": + image = latents + + else: + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + if not return_dict: + return (image,) + + return StableDiffusion3PipelineOutput(images=image) diff --git a/mindone/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py b/mindone/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py index 761f26603c..50f6c850e5 100644 --- a/mindone/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +++ b/mindone/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py @@ -172,7 +172,7 @@ def preprocess(image): def preprocess_mask(mask, batch_size: int = 1): if not isinstance(mask, ms.Tensor): # preprocess mask - if isinstance(mask, PIL.Image.Image) or isinstance(mask, np.ndarray): + if isinstance(mask, (PIL.Image.Image, np.ndarray)): mask = [mask] if isinstance(mask, list): @@ -559,9 +559,10 @@ def encode_prompt( negative_prompt_embeds = negative_prompt_embeds.tile((1, num_images_per_prompt, 1)) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - if isinstance(self, LoraLoaderMixin): - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder is not None: + if isinstance(self, LoraLoaderMixin): + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) return prompt_embeds, negative_prompt_embeds @@ -715,7 +716,12 @@ def get_inverse_timesteps(self, num_inference_steps, strength): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -1225,7 +1231,7 @@ def __call__( callback: Optional[Callable[[int, int, ms.Tensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, - clip_ckip: int = None, + clip_skip: int = None, ): r""" The call function to the pipeline for generation. @@ -1345,7 +1351,7 @@ def __call__( prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, - clip_skip=clip_ckip, + clip_skip=clip_skip, ) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch diff --git a/mindone/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py b/mindone/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py index a37ae79f1f..82526382ff 100644 --- a/mindone/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +++ b/mindone/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py @@ -362,9 +362,10 @@ def encode_prompt( negative_prompt_embeds = negative_prompt_embeds.tile((1, num_images_per_prompt, 1)) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - if isinstance(self, LoraLoaderMixin): - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder is not None: + if isinstance(self, LoraLoaderMixin): + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) return prompt_embeds, negative_prompt_embeds @@ -457,14 +458,19 @@ def check_inputs( ) if len(gligen_phrases) != len(gligen_boxes): - ValueError( + raise ValueError( "length of `gligen_phrases` and `gligen_boxes` has to be same, but" f" got: `gligen_phrases` {len(gligen_phrases)} != `gligen_boxes` {len(gligen_boxes)}" ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -666,7 +672,7 @@ def __call__( timesteps = self.scheduler.timesteps # 5. Prepare latent variables - num_channels_latents = self.unet.in_channels + num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, @@ -699,7 +705,7 @@ def __call__( # we represent the location information as (xmin,ymin,xmax,ymax) boxes = ops.zeros((max_objs, 4), dtype=self.text_encoder.dtype) boxes[:n_objs] = ms.Tensor(gligen_boxes) - text_embeddings = ops.zeros((max_objs, self.unet.cross_attention_dim), dtype=self.text_encoder.dtype) + text_embeddings = ops.zeros((max_objs, self.unet.config.cross_attention_dim), dtype=self.text_encoder.dtype) text_embeddings[:n_objs] = _text_embeddings # Generate a mask for each object that is entity described by phrases masks = ops.zeros((max_objs,), dtype=self.text_encoder.dtype) diff --git a/mindone/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py b/mindone/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py index 2bf2eb2aea..154a987b4a 100644 --- a/mindone/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +++ b/mindone/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py @@ -389,9 +389,10 @@ def encode_prompt( negative_prompt_embeds = negative_prompt_embeds.tile((1, num_images_per_prompt, 1)) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - if isinstance(self, LoraLoaderMixin): - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder is not None: + if isinstance(self, LoraLoaderMixin): + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) return prompt_embeds, negative_prompt_embeds @@ -490,7 +491,12 @@ def check_inputs( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -837,7 +843,7 @@ def __call__( timesteps = self.scheduler.timesteps # 5. Prepare latent variables - num_channels_latents = self.unet.in_channels + num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, diff --git a/mindone/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/mindone/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 401c62fdc8..9eedfb7e95 100644 --- a/mindone/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/mindone/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -22,6 +22,7 @@ from mindone.transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPVisionModelWithProjection +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import ( FromSingleFileMixin, @@ -60,8 +61,8 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 """ - std_text = noise_pred_text.std(axis=list(range(1, noise_pred_text.ndim)), keepdims=True) - std_cfg = noise_cfg.std(axis=list(range(1, noise_cfg.ndim)), keepdims=True) + std_text = noise_pred_text.std(axis=tuple(range(1, noise_pred_text.ndim)), keepdims=True) + std_cfg = noise_cfg.std(axis=tuple(range(1, noise_cfg.ndim)), keepdims=True) # rescale the results from guidance (fixes overexposure) noise_pred_rescaled = noise_cfg * (std_text / std_cfg) # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images @@ -74,6 +75,7 @@ def retrieve_timesteps( scheduler, num_inference_steps: Optional[int] = None, timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, **kwargs, ): """ @@ -84,17 +86,21 @@ def retrieve_timesteps( scheduler (`SchedulerMixin`): The scheduler to get timesteps from. num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, - `timesteps` must be `None`. + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default - timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` - must be `None`. + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[ms.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: @@ -105,6 +111,16 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, **kwargs) timesteps = scheduler.timesteps @@ -551,7 +567,12 @@ def check_inputs( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -589,7 +610,9 @@ def upcast_vae(self): self.vae.to(dtype=ms.float32) # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding - def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=ms.float32): + def get_guidance_scale_embedding( + self, w: ms.Tensor, embedding_dim: int = 512, dtype: ms.Type = ms.float32 + ) -> ms.Tensor: """ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 @@ -598,7 +621,7 @@ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=ms.float32): Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. embedding_dim (`int`, *optional*, defaults to 512): Dimension of the embeddings to generate. - dtype: + dtype (`ms.dtype`, *optional*, defaults to `ms.float32`): Data type of the generated embeddings. Returns: @@ -660,6 +683,7 @@ def __call__( width: Optional[int] = None, num_inference_steps: int = 50, timesteps: List[int] = None, + sigmas: List[float] = None, denoising_end: Optional[float] = None, guidance_scale: float = 5.0, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -685,7 +709,9 @@ def __call__( negative_crops_coords_top_left: Tuple[int, int] = (0, 0), negative_target_size: Optional[Tuple[int, int]] = None, clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): @@ -716,6 +742,10 @@ def __call__( Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. denoising_end (`float`, *optional*): When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be completed before it is intentionally prematurely terminated. As a result, the returned sample will @@ -812,11 +842,11 @@ def __call__( as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the @@ -846,6 +876,9 @@ def __call__( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + # 0. Default height and width to unet height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor @@ -910,7 +943,7 @@ def __call__( ) # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps) + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps, sigmas) # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels diff --git a/mindone/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/mindone/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 6a7c324d61..1c433ef3bd 100644 --- a/mindone/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/mindone/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -24,6 +24,7 @@ from mindone.transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPVisionModelWithProjection +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import ( FromSingleFileMixin, @@ -66,8 +67,8 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 """ - std_text = noise_pred_text.std(axis=list(range(1, noise_pred_text.ndim)), keepdims=True) - std_cfg = noise_cfg.std(axis=list(range(1, noise_cfg.ndim)), keepdims=True) + std_text = noise_pred_text.std(axis=tuple(range(1, noise_pred_text.ndim)), keepdims=True) + std_cfg = noise_cfg.std(axis=tuple(range(1, noise_cfg.ndim)), keepdims=True) # rescale the results from guidance (fixes overexposure) noise_pred_rescaled = noise_cfg * (std_text / std_cfg) # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images @@ -94,6 +95,7 @@ def retrieve_timesteps( scheduler, num_inference_steps: Optional[int] = None, timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, **kwargs, ): """ @@ -104,17 +106,21 @@ def retrieve_timesteps( scheduler (`SchedulerMixin`): The scheduler to get timesteps from. num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, - `timesteps` must be `None`. + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default - timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` - must be `None`. + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. Returns: - `Tuple[ms.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the - second element is the number of inference steps. + `Tuple[ms.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and + the second element is the number of inference steps. """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: @@ -125,6 +131,16 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, **kwargs) timesteps = scheduler.timesteps @@ -610,6 +626,12 @@ def prepare_latents( f"`image` has to be of type `mindspore.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = ms.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = ms.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + image = image.to(dtype=dtype) batch_size = batch_size * num_images_per_prompt @@ -642,7 +664,12 @@ def prepare_latents( self.vae.to(dtype) init_latents = init_latents.to(dtype) - init_latents = self.vae.config.scaling_factor * init_latents + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(dtype=dtype) + latents_std = latents_std.to(dtype=dtype) + init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std + else: + init_latents = self.vae.config.scaling_factor * init_latents if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: # expand init_latents for batch_size @@ -728,7 +755,9 @@ def upcast_vae(self): self.vae.to(dtype=ms.float32) # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding - def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=ms.float32): + def get_guidance_scale_embedding( + self, w: ms.Tensor, embedding_dim: int = 512, dtype: ms.Type = ms.float32 + ) -> ms.Tensor: """ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 @@ -737,7 +766,7 @@ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=ms.float32): Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. embedding_dim (`int`, *optional*, defaults to 512): Dimension of the embeddings to generate. - dtype: + dtype (`ms.dtype`, *optional*, defaults to `ms.float32`): Data type of the generated embeddings. Returns: @@ -803,6 +832,7 @@ def __call__( strength: float = 0.3, num_inference_steps: int = 50, timesteps: List[int] = None, + sigmas: List[float] = None, denoising_start: Optional[float] = None, denoising_end: Optional[float] = None, guidance_scale: float = 5.0, @@ -831,7 +861,9 @@ def __call__( aesthetic_score: float = 6.0, negative_aesthetic_score: float = 2.5, clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): @@ -861,6 +893,10 @@ def __call__( Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. denoising_start (`float`, *optional*): When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and @@ -976,11 +1012,11 @@ def __call__( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the @@ -1010,6 +1046,9 @@ def __call__( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, @@ -1073,7 +1112,7 @@ def __call__( def denoising_value_valid(dnv): return isinstance(dnv, float) and 0 < dnv < 1 - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps) + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps, sigmas) timesteps, num_inference_steps = self.get_timesteps( num_inference_steps, strength, @@ -1082,16 +1121,18 @@ def denoising_value_valid(dnv): latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) add_noise = True if self.denoising_start is None else False + # 6. Prepare latent variables - latents = self.prepare_latents( - image, - latent_timestep, - batch_size, - num_images_per_prompt, - prompt_embeds.dtype, - generator, - add_noise, - ) + if latents is None: + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + generator, + add_noise, + ) # 7. Prepare extra step kwargs. extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) diff --git a/mindone/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/mindone/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 6645c9303a..91eea66b07 100644 --- a/mindone/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/mindone/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -24,6 +24,7 @@ from mindone.transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPVisionModelWithProjection +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import ( FromSingleFileMixin, @@ -74,8 +75,8 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 """ - std_text = noise_pred_text.std(axis=list(range(1, noise_pred_text.ndim)), keepdims=True) - std_cfg = noise_cfg.std(axis=list(range(1, noise_cfg.ndim)), keepdims=True) + std_text = noise_pred_text.std(axis=tuple(range(1, noise_pred_text.ndim)), keepdims=True) + std_cfg = noise_cfg.std(axis=tuple(range(1, noise_cfg.ndim)), keepdims=True) # rescale the results from guidance (fixes overexposure) noise_pred_rescaled = noise_cfg * (std_text / std_cfg) # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images @@ -102,29 +103,29 @@ def mask_pil_to_ms(mask, height, width): def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False): """ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be - converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the + converted to ``ms.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the ``image`` and ``1`` for the ``mask``. - The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be - binarized (``mask > 0.5``) and cast to ``torch.float32`` too. + The ``image`` will be converted to ``ms.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be + binarized (``mask > 0.5``) and cast to ``ms.float32`` too. Args: - image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. + image (Union[np.array, PIL.Image, ms.Tensor]): The image to inpaint. It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` - ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. + ``ms.Tensor`` or a ``batch x channels x height x width`` ``ms.Tensor``. mask (_type_): The mask to apply to the image, i.e. regions to inpaint. It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` - ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. + ``ms.Tensor`` or a ``batch x 1 x height x width`` ``ms.Tensor``. Raises: - ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask + ValueError: ``ms.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``ms.Tensor`` mask should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. - TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not + TypeError: ``mask`` is a ``ms.Tensor`` but ``image`` is not (ot the other way around). Returns: - tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 + tuple[ms.Tensor]: The pair (mask, masked_image) as ``ms.Tensor`` with 4 dimensions: ``batch x channels x height x width``. """ @@ -236,6 +237,7 @@ def retrieve_timesteps( scheduler, num_inference_steps: Optional[int] = None, timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, **kwargs, ): """ @@ -246,17 +248,21 @@ def retrieve_timesteps( scheduler (`SchedulerMixin`): The scheduler to get timesteps from. num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, - `timesteps` must be `None`. + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default - timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` - must be `None`. + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. Returns: - `Tuple[ms.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the - second element is the number of inference steps. + `Tuple[ms.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and + the second element is the number of inference steps. """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: @@ -267,6 +273,16 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, **kwargs) timesteps = scheduler.timesteps @@ -745,7 +761,12 @@ def prepare_latents( return_noise=False, return_image_latents=False, ): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -963,7 +984,9 @@ def upcast_vae(self): self.vae.to(dtype=ms.float32) # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding - def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=ms.float32): + def get_guidance_scale_embedding( + self, w: ms.Tensor, embedding_dim: int = 512, dtype: ms.Type = ms.float32 + ) -> ms.Tensor: """ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 @@ -972,7 +995,7 @@ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=ms.float32): Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. embedding_dim (`int`, *optional*, defaults to 512): Dimension of the embeddings to generate. - dtype: + dtype (`ms.dtype`, *optional*, defaults to `ms.float32`): Data type of the generated embeddings. Returns: @@ -1043,6 +1066,7 @@ def __call__( strength: float = 0.9999, num_inference_steps: int = 50, timesteps: List[int] = None, + sigmas: List[float] = None, denoising_start: Optional[float] = None, denoising_end: Optional[float] = None, guidance_scale: float = 7.5, @@ -1071,7 +1095,9 @@ def __call__( aesthetic_score: float = 6.0, negative_aesthetic_score: float = 2.5, clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): @@ -1104,11 +1130,12 @@ def __call__( [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not specifically fine-tuned on low resolutions. padding_mask_crop (`int`, *optional*, defaults to `None`): - The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to image and mask_image. If - `padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ration of the image and - contains all masked area, and then expand that area based on `padding_mask_crop`. The image and mask_image will then be cropped based on - the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small - while the image is large and contain information inreleant for inpainging, such as background. + The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to + image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region + with the same aspect ration of the image and contains all masked area, and then expand that area based + on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before + resizing to the original image size for inpainting. This is useful when the masked area is small while + the image is large and contain information irrelevant for inpainting, such as background. strength (`float`, *optional*, defaults to 0.9999): Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the @@ -1124,6 +1151,10 @@ def __call__( Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. denoising_start (`float`, *optional*): When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and @@ -1234,11 +1265,11 @@ def __call__( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the @@ -1268,6 +1299,9 @@ def __call__( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor @@ -1338,7 +1372,7 @@ def __call__( def denoising_value_valid(dnv): return isinstance(dnv, float) and 0 < dnv < 1 - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps) + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps, sigmas) timesteps, num_inference_steps = self.get_timesteps( num_inference_steps, strength, diff --git a/mindone/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/mindone/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index a658c907cf..727c2160c7 100644 --- a/mindone/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/mindone/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -87,8 +87,8 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 """ - std_text = noise_pred_text.std(axis=list(range(1, noise_pred_text.ndim)), keepdims=True) - std_cfg = noise_cfg.std(axis=list(range(1, noise_cfg.ndim)), keepdims=True) + std_text = noise_pred_text.std(axis=tuple(range(1, noise_pred_text.ndim)), keepdims=True) + std_cfg = noise_cfg.std(axis=tuple(range(1, noise_cfg.ndim)), keepdims=True) # rescale the results from guidance (fixes overexposure) noise_pred_rescaled = noise_cfg * (std_text / std_cfg) # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images @@ -147,6 +147,8 @@ class StableDiffusionXLInstructPix2PixPipeline( Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to watermark output images. If not defined, it will default to True if the package is installed, otherwise no watermarker will be used. + is_cosxl_edit (`bool`, *optional*): + When set the image latents are scaled. """ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" @@ -163,6 +165,7 @@ def __init__( scheduler: KarrasDiffusionSchedulers, force_zeros_for_empty_prompt: bool = True, add_watermarker: Optional[bool] = None, + is_cosxl_edit: Optional[bool] = False, ): super().__init__() @@ -179,6 +182,7 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.default_sample_size = self.unet.config.sample_size + self.is_cosxl_edit = is_cosxl_edit if add_watermarker: logger.warning("watermarker is not supported!") @@ -445,7 +449,12 @@ def check_inputs( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -477,8 +486,8 @@ def prepare_image_latents(self, image, batch_size, num_images_per_prompt, dtype, # make sure the VAE is in float32 mode, as it overflows in float16 needs_upcasting = self.vae.dtype == ms.float16 and self.vae.config.force_upcast if needs_upcasting: + image = image.float() self.upcast_vae() - image = image.to(next(iter(self.vae.post_quant_conv.get_parameters())).dtype) image_latents = retrieve_latents(self.vae, self.vae.encode(image)[0], sample_mode="argmax") @@ -511,6 +520,9 @@ def prepare_image_latents(self, image, batch_size, num_images_per_prompt, dtype, if image_latents.dtype != self.vae.dtype: image_latents = image_latents.to(dtype=self.vae.dtype) + if self.is_cosxl_edit: + image_latents = image_latents * self.vae.config.scaling_factor + return image_latents # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids diff --git a/mindone/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/mindone/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py index 19bfe2314c..7050cb0c3d 100644 --- a/mindone/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +++ b/mindone/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py @@ -25,11 +25,12 @@ from mindone.transformers import CLIPVisionModelWithProjection -from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...image_processor import PipelineImageInput from ...models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel from ...schedulers import EulerDiscreteScheduler from ...utils import BaseOutput, logging from ...utils.mindspore_utils import randn_tensor +from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -60,26 +61,61 @@ def _append_dims(x, target_dims): return x[(...,) + (None,) * dims_to_append] -# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid -def tensor2vid(video: ms.Tensor, processor: VaeImageProcessor, output_type: str = "np"): - batch_size, channels, num_frames, height, width = video.shape - outputs = [] - for batch_idx in range(batch_size): - batch_vid = video[batch_idx].permute((1, 0, 2, 3)) - batch_output = processor.postprocess(batch_vid, output_type) - - outputs.append(batch_output) - - if output_type == "np": - outputs = np.stack(outputs) - - elif output_type == "ms": - outputs = ops.stack(outputs) - - elif not output_type == "pil": - raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'ms', 'pil']") +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. - return outputs + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[ms.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and + the second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps @dataclass @@ -89,8 +125,8 @@ class StableVideoDiffusionPipelineOutput(BaseOutput): Args: frames (`[List[List[PIL.Image.Image]]`, `np.ndarray`, `ms.Tensor`]): - List of denoised PIL images of length `batch_size` or numpy array or torch tensor - of shape `(batch_size, num_frames, height, width, num_channels)`. + List of denoised PIL images of length `batch_size` or numpy array or ms tensor of shape `(batch_size, + num_frames, height, width, num_channels)`. """ frames: Union[List[List[PIL.Image.Image]], np.ndarray, ms.Tensor] @@ -107,7 +143,8 @@ class StableVideoDiffusionPipeline(DiffusionPipeline): vae ([`AutoencoderKLTemporalDecoder`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. image_encoder ([`~transformers.CLIPVisionModelWithProjection`]): - Frozen CLIP image-encoder ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)). + Frozen CLIP image-encoder + ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)). unet ([`UNetSpatioTemporalConditionModel`]): A `UNetSpatioTemporalConditionModel` to denoise the encoded image latents. scheduler ([`EulerDiscreteScheduler`]): @@ -137,7 +174,7 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.video_processor = VideoProcessor(do_resize=True, vae_scale_factor=self.vae_scale_factor) def _encode_image( self, @@ -148,8 +185,8 @@ def _encode_image( dtype = next(self.image_encoder.get_parameters()).dtype if not isinstance(image, ms.Tensor): - image = self.image_processor.pil_to_numpy(image) - image = self.image_processor.numpy_to_pt(image) + image = self.video_processor.pil_to_numpy(image) + image = self.video_processor.numpy_to_pt(image) # We normalize the image before resizing to match with the original implementation. # Then we unnormalize it after resizing. @@ -194,6 +231,9 @@ def _encode_vae_image( ): image_latents = self.vae.diag_gauss_dist.mode(self.vae.encode(image)[0]) + # duplicate image_latents for each generation per prompt, using mps friendly method + image_latents = image_latents.tile((num_videos_per_prompt, 1, 1, 1)) + if do_classifier_free_guidance: negative_image_latents = ops.zeros_like(image_latents) @@ -202,9 +242,6 @@ def _encode_vae_image( # to avoid doing two forward passes image_latents = ops.cat([negative_image_latents, image_latents]) - # duplicate image_latents for each generation per prompt, using mps friendly method - image_latents = image_latents.tile((num_videos_per_prompt, 1, 1, 1)) - return image_latents def _get_add_time_ids( @@ -331,6 +368,7 @@ def __call__( width: int = 1024, num_frames: Optional[int] = None, num_inference_steps: int = 25, + sigmas: Optional[List[float]] = None, min_guidance_scale: float = 1.0, max_guidance_scale: float = 3.0, fps: int = 7, @@ -361,6 +399,10 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 25): The number of denoising steps. More denoising steps usually lead to a higher quality video at the expense of slower inference. This parameter is modulated by `strength`. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. min_guidance_scale (`float`, *optional*, defaults to 1.0): The minimum guidance scale. Used for the classifier free guidance with first frame. max_guidance_scale (`float`, *optional*, defaults to 3.0): @@ -437,7 +479,7 @@ def __call__( fps = fps - 1 # 4. Encode input image using VAE - image = self.image_processor.preprocess(image, height=height, width=width) + image = self.video_processor.preprocess(image, height=height, width=width) noise = randn_tensor(image.shape, generator=generator, dtype=image.dtype) image = image + noise_aug_strength * noise @@ -472,8 +514,7 @@ def __call__( ) # 6. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps) - timesteps = self.scheduler.timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, None, sigmas) # 7. Prepare latent variables num_channels_latents = self.unet.config.in_channels @@ -543,7 +584,7 @@ def __call__( if needs_upcasting: self.vae.to(dtype=ms.float16) frames = self.decode_latents(latents, num_frames, decode_chunk_size) - frames = tensor2vid(frames, self.image_processor, output_type=output_type) + frames = self.video_processor.postprocess_video(video=frames, output_type=output_type) else: frames = latents @@ -616,7 +657,7 @@ def _filter2d(input, kernel): height, width = tmp_kernel.shape[-2:] - padding_shape: list[int] = _compute_padding([height, width]) + padding_shape: List[int] = _compute_padding([height, width]) input = ops.pad(input, padding_shape, mode="reflect") # kernel and input tensor reshape to align element-wise or batch-wise params diff --git a/mindone/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/mindone/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py index a459d8d75b..daaa808593 100644 --- a/mindone/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +++ b/mindone/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py @@ -114,6 +114,7 @@ def retrieve_timesteps( scheduler, num_inference_steps: Optional[int] = None, timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, **kwargs, ): """ @@ -124,17 +125,21 @@ def retrieve_timesteps( scheduler (`SchedulerMixin`): The scheduler to get timesteps from. num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, - `timesteps` must be `None`. + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default - timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` - must be `None`. + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[ms.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: @@ -145,6 +150,16 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, **kwargs) timesteps = scheduler.timesteps @@ -433,9 +448,10 @@ def encode_prompt( negative_prompt_embeds = negative_prompt_embeds.tile((1, num_images_per_prompt, 1)) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - if isinstance(self, LoraLoaderMixin): - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder is not None: + if isinstance(self, LoraLoaderMixin): + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) return prompt_embeds, negative_prompt_embeds @@ -553,7 +569,12 @@ def check_inputs( # Copied from mindone.diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -599,20 +620,22 @@ def _default_height_width(self, height, width, image): return height, width # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding - def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=ms.float32): + def get_guidance_scale_embedding( + self, w: ms.Tensor, embedding_dim: int = 512, dtype: ms.Type = ms.float32 + ) -> ms.Tensor: """ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 Args: - timesteps (`ms.Tensor`): - generate embedding vectors at these timesteps + w (`ms.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. embedding_dim (`int`, *optional*, defaults to 512): - dimension of the embeddings to generate - dtype: - data type of the generated embeddings + Dimension of the embeddings to generate. + dtype (`ms.dtype`, *optional*, defaults to `ms.float32`): + Data type of the generated embeddings. Returns: - `ms.Tensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)` + `ms.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. """ assert len(w.shape) == 1 w = w * 1000.0 @@ -646,6 +669,7 @@ def __call__( width: Optional[int] = None, num_inference_steps: int = 50, timesteps: List[int] = None, + sigmas: List[float] = None, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, @@ -671,7 +695,7 @@ def __call__( instead. image (`ms.Tensor`, `PIL.Image.Image`, `List[ms.Tensor]` or `List[PIL.Image.Image]` or `List[List[PIL.Image.Image]]`): The Adapter input condition. Adapter uses this input condition to generate guidance to Unet. If the - type is specified as `Torch.FloatTensor`, it is passed to Adapter as is. PIL.Image.Image` can also be + type is specified as `ms.FloatTensor`, it is passed to Adapter as is. PIL.Image.Image` can also be accepted as an image. The control image is automatically resized to fit the output image. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. @@ -684,6 +708,10 @@ def __call__( Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen @@ -790,7 +818,7 @@ def __call__( prompt_embeds = ops.cat([negative_prompt_embeds, prompt_embeds]) # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps) + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps, sigmas) # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels diff --git a/mindone/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/mindone/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py index f508edc5de..37bed95c9e 100644 --- a/mindone/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +++ b/mindone/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py @@ -105,8 +105,8 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 """ - std_text = noise_pred_text.std(axis=list(range(1, noise_pred_text.ndim)), keepdims=True) - std_cfg = noise_cfg.std(axis=list(range(1, noise_cfg.ndim)), keepdims=True) + std_text = noise_pred_text.std(axis=tuple(range(1, noise_pred_text.ndim)), keepdims=True) + std_cfg = noise_cfg.std(axis=tuple(range(1, noise_cfg.ndim)), keepdims=True) # rescale the results from guidance (fixes overexposure) noise_pred_rescaled = noise_cfg * (std_text / std_cfg) # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images @@ -119,6 +119,7 @@ def retrieve_timesteps( scheduler, num_inference_steps: Optional[int] = None, timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, **kwargs, ): """ @@ -129,17 +130,21 @@ def retrieve_timesteps( scheduler (`SchedulerMixin`): The scheduler to get timesteps from. num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, - `timesteps` must be `None`. + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default - timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` - must be `None`. + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[ms.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: @@ -150,6 +155,16 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, **kwargs) timesteps = scheduler.timesteps @@ -475,11 +490,9 @@ def encode_image(self, image, num_images_per_prompt, output_hidden_states=None): image = image.to(dtype=dtype) if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True)[2][-2] image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - ops.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] + uncond_image_enc_hidden_states = self.image_encoder(ops.zeros_like(image), output_hidden_states=True)[2][-2] uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( num_images_per_prompt, dim=0 ) @@ -657,7 +670,12 @@ def check_inputs( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -725,20 +743,22 @@ def _default_height_width(self, height, width, image): return height, width # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding - def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=ms.float32): + def get_guidance_scale_embedding( + self, w: ms.Tensor, embedding_dim: int = 512, dtype: ms.Type = ms.float32 + ) -> ms.Tensor: """ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 Args: - timesteps (`ms.Tensor`): - generate embedding vectors at these timesteps + w (`ms.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. embedding_dim (`int`, *optional*, defaults to 512): - dimension of the embeddings to generate - dtype: - data type of the generated embeddings + Dimension of the embeddings to generate. + dtype (`ms.dtype`, *optional*, defaults to `ms.float32`): + Data type of the generated embeddings. Returns: - `ms.Tensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)` + `ms.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. """ assert len(w.shape) == 1 w = w * 1000.0 @@ -772,6 +792,7 @@ def __call__( height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, + sigmas: List[float] = None, timesteps: List[int] = None, denoising_end: Optional[float] = None, guidance_scale: float = 5.0, @@ -815,7 +836,7 @@ def __call__( used in both text-encoders image (`ms.Tensor`, `PIL.Image.Image`, `List[ms.Tensor]` or `List[PIL.Image.Image]` or `List[List[PIL.Image.Image]]`): The Adapter input condition. Adapter uses this input condition to generate guidance to Unet. If the - type is specified as `Torch.FloatTensor`, it is passed to Adapter as is. PIL.Image.Image` can also be + type is specified as `ms.FloatTensor`, it is passed to Adapter as is. PIL.Image.Image` can also be accepted as an image. The control image is automatically resized to fit the output image. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. Anything below 512 pixels won't work well for @@ -832,6 +853,10 @@ def __call__( Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. denoising_end (`float`, *optional*): When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be completed before it is intentionally prematurely terminated. As a result, the returned sample will @@ -1028,7 +1053,7 @@ def __call__( ) # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps) + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps, sigmas) # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels diff --git a/mindone/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py b/mindone/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py index f6efbf4cf5..720b0804b0 100644 --- a/mindone/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +++ b/mindone/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py @@ -19,8 +19,8 @@ def construct(self, x): class TimestepBlock(nn.Cell): def __init__(self, c, c_timestep): super().__init__() - linear_cls = nn.Dense - self.mapper = linear_cls(c_timestep, c * 2) + + self.mapper = nn.Dense(c_timestep, c * 2) def construct(self, x, t): a, b = self.mapper(t)[:, :, None, None].chunk(2, axis=1) @@ -31,15 +31,12 @@ class ResBlock(nn.Cell): def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0): super().__init__() - conv_cls = nn.Conv2d - linear_cls = nn.Dense - - self.depthwise = conv_cls( + self.depthwise = nn.Conv2d( c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, group=c, has_bias=True, pad_mode="pad" ) self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6) self.channelwise = nn.SequentialCell( - linear_cls(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(p=dropout), linear_cls(c * 4, c) + nn.Dense(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(p=dropout), nn.Dense(c * 4, c) ) def construct(self, x, x_skip=None): @@ -68,12 +65,10 @@ class AttnBlock(nn.Cell): def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): super().__init__() - linear_cls = nn.Dense - self.self_attn = self_attn self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6) self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True) - self.kv_mapper = nn.SequentialCell(SiLU(), linear_cls(c_cond, c)) + self.kv_mapper = nn.SequentialCell(SiLU(), nn.Dense(c_cond, c)) def construct(self, x, kv): kv = self.kv_mapper(kv) diff --git a/mindone/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py b/mindone/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py index 88a3589668..cc76b63741 100644 --- a/mindone/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +++ b/mindone/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py @@ -38,15 +38,13 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft @register_to_config def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1): super().__init__() - conv_cls = nn.Conv2d - linear_cls = nn.Dense self.c_r = c_r - self.projection = conv_cls(c_in, c, kernel_size=1, has_bias=True, pad_mode="valid") + self.projection = nn.Conv2d(c_in, c, kernel_size=1, has_bias=True, pad_mode="valid") self.cond_mapper = nn.SequentialCell( - linear_cls(c_cond, c), + nn.Dense(c_cond, c), nn.LeakyReLU(0.2), - linear_cls(c, c), + nn.Dense(c, c), ) blocks = [] @@ -57,7 +55,7 @@ def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dro self.blocks = nn.CellList(blocks) self.out = nn.SequentialCell( WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6), - conv_cls(c, c_in * 2, kernel_size=1, has_bias=True, pad_mode="valid"), + nn.Conv2d(c, c_in * 2, kernel_size=1, has_bias=True, pad_mode="valid"), ) self._gradient_checkpointing = False diff --git a/mindone/diffusers/schedulers/__init__.py b/mindone/diffusers/schedulers/__init__.py index f05872377f..02f238a60e 100644 --- a/mindone/diffusers/schedulers/__init__.py +++ b/mindone/diffusers/schedulers/__init__.py @@ -47,7 +47,7 @@ "scheduling_unclip": ["UnCLIPScheduler"], "scheduling_unipc_multistep": ["UniPCMultistepScheduler"], "scheduling_vq_diffusion": ["VQDiffusionScheduler"], - "scheduling_utils": ["KarrasDiffusionSchedulers", "SchedulerMixin"], + "scheduling_utils": ["AysSchedules", "KarrasDiffusionSchedulers", "SchedulerMixin"], } if TYPE_CHECKING: @@ -81,7 +81,7 @@ from .scheduling_tcd import TCDScheduler from .scheduling_unclip import UnCLIPScheduler from .scheduling_unipc_multistep import UniPCMultistepScheduler - from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin + from .scheduling_utils import AysSchedules, KarrasDiffusionSchedulers, SchedulerMixin from .scheduling_vq_diffusion import VQDiffusionScheduler else: diff --git a/mindone/diffusers/schedulers/scheduling_ddim.py b/mindone/diffusers/schedulers/scheduling_ddim.py index 02652bec34..7c6a1c3b37 100644 --- a/mindone/diffusers/schedulers/scheduling_ddim.py +++ b/mindone/diffusers/schedulers/scheduling_ddim.py @@ -83,7 +83,7 @@ def alpha_bar_fn(t): return math.exp(t * -12.0) else: - raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") betas = [] for i in range(num_diffusion_timesteps): @@ -214,7 +214,7 @@ def __init__( # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") # Rescale for zero SNR if rescale_betas_zero_snr: diff --git a/mindone/diffusers/schedulers/scheduling_ddim_inverse.py b/mindone/diffusers/schedulers/scheduling_ddim_inverse.py index 7fd54bdf19..b9b5aca33b 100644 --- a/mindone/diffusers/schedulers/scheduling_ddim_inverse.py +++ b/mindone/diffusers/schedulers/scheduling_ddim_inverse.py @@ -211,7 +211,7 @@ def __init__( # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") # Rescale for zero SNR if rescale_betas_zero_snr: diff --git a/mindone/diffusers/schedulers/scheduling_ddim_parallel.py b/mindone/diffusers/schedulers/scheduling_ddim_parallel.py index d922917ec2..225e966eb1 100644 --- a/mindone/diffusers/schedulers/scheduling_ddim_parallel.py +++ b/mindone/diffusers/schedulers/scheduling_ddim_parallel.py @@ -222,7 +222,7 @@ def __init__( # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") # Rescale for zero SNR if rescale_betas_zero_snr: diff --git a/mindone/diffusers/schedulers/scheduling_ddpm.py b/mindone/diffusers/schedulers/scheduling_ddpm.py index 00cb550016..a38518cec9 100644 --- a/mindone/diffusers/schedulers/scheduling_ddpm.py +++ b/mindone/diffusers/schedulers/scheduling_ddpm.py @@ -80,7 +80,7 @@ def alpha_bar_fn(t): return math.exp(t * -12.0) else: - raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") betas = [] for i in range(num_diffusion_timesteps): @@ -214,7 +214,7 @@ def __init__( betas = ms.tensor(np.linspace(-6, 6, num_train_timesteps), dtype=ms.float32) self.betas = ops.sigmoid(betas) * (beta_end - beta_start) + beta_start else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") # Rescale for zero SNR if rescale_betas_zero_snr: diff --git a/mindone/diffusers/schedulers/scheduling_ddpm_parallel.py b/mindone/diffusers/schedulers/scheduling_ddpm_parallel.py index f893016efe..8237f78679 100644 --- a/mindone/diffusers/schedulers/scheduling_ddpm_parallel.py +++ b/mindone/diffusers/schedulers/scheduling_ddpm_parallel.py @@ -223,7 +223,7 @@ def __init__( betas = ms.tensor(np.linspace(-6, 6, num_train_timesteps)) self.betas = ops.sigmoid(betas) * (beta_end - beta_start) + beta_start else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") # Rescale for zero SNR if rescale_betas_zero_snr: diff --git a/mindone/diffusers/schedulers/scheduling_deis_multistep.py b/mindone/diffusers/schedulers/scheduling_deis_multistep.py index 9f5059e2fd..99fc3375dd 100644 --- a/mindone/diffusers/schedulers/scheduling_deis_multistep.py +++ b/mindone/diffusers/schedulers/scheduling_deis_multistep.py @@ -156,7 +156,7 @@ def __init__( # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas self.alphas_cumprod = ops.cumprod(self.alphas, dim=0) @@ -174,13 +174,13 @@ def __init__( if algorithm_type in ["dpmsolver", "dpmsolver++"]: self.register_to_config(algorithm_type="deis") else: - raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") + raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}") if solver_type not in ["logrho"]: if solver_type in ["midpoint", "heun", "bh1", "bh2"]: self.register_to_config(solver_type="logrho") else: - raise NotImplementedError(f"solver type {solver_type} does is not implemented for {self.__class__}") + raise NotImplementedError(f"solver type {solver_type} is not implemented for {self.__class__}") # setable values self.num_inference_steps = None diff --git a/mindone/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/mindone/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 55a8e399f5..6c59561ec8 100644 --- a/mindone/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/mindone/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -236,7 +236,7 @@ def __init__( # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") if rescale_betas_zero_snr: self.betas = rescale_zero_terminal_snr(self.betas) @@ -263,13 +263,13 @@ def __init__( if algorithm_type == "deis": self.register_to_config(algorithm_type="dpmsolver++") else: - raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") + raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}") if solver_type not in ["midpoint", "heun"]: if solver_type in ["logrho", "bh1", "bh2"]: self.register_to_config(solver_type="midpoint") else: - raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") + raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero": raise ValueError( @@ -309,42 +309,63 @@ def set_begin_index(self, begin_index: int = 0): """ self._begin_index = begin_index - def set_timesteps(self, num_inference_steps: int = None): + def set_timesteps(self, num_inference_steps: int = None, timesteps: Optional[List[int]] = None): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). Args: num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated + based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas` + must be `None`, and `timestep_spacing` attribute will be ignored. """ - # Clipping the minimum of all lambda(t) for numerical stability. - # This is critical for cosine (squaredcos_cap_v2) noise schedule. - clipped_idx = ms.tensor( - np.searchsorted(ops.flip(self.lambda_t, [0]).asnumpy(), self.config.lambda_min_clipped), dtype=ms.int64 - ) - last_timestep = ((self.config.num_train_timesteps - clipped_idx).asnumpy()).item() - - # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 - if self.config.timestep_spacing == "linspace": - timesteps = ( - np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64) - ) - elif self.config.timestep_spacing == "leading": - step_ratio = last_timestep // (num_inference_steps + 1) - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64) - timesteps += self.config.steps_offset - elif self.config.timestep_spacing == "trailing": - step_ratio = self.config.num_train_timesteps / num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64) - timesteps -= 1 + if num_inference_steps is None and timesteps is None: + raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.") + if num_inference_steps is not None and timesteps is not None: + raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.") + if timesteps is not None and self.config.use_karras_sigmas: + raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`") + if timesteps is not None and self.config.use_lu_lambdas: + raise ValueError("Cannot use `timesteps` with `config.use_lu_lambdas = True`") + + if timesteps is not None: + timesteps = np.array(timesteps).astype(np.int64) else: - raise ValueError( - f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + # Clipping the minimum of all lambda(t) for numerical stability. + # This is critical for cosine (squaredcos_cap_v2) noise schedule. + clipped_idx = ms.tensor( + np.searchsorted(ops.flip(self.lambda_t, [0]).asnumpy(), self.config.lambda_min_clipped), dtype=ms.int64 ) + last_timestep = ((self.config.num_train_timesteps - clipped_idx).asnumpy()).item() + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, last_timestep - 1, num_inference_steps + 1) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = last_timestep // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = ( + (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64) + ) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) sigmas = (((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5).asnumpy() log_sigmas = np.log(sigmas) diff --git a/mindone/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/mindone/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index 1c1f41ebb2..5dd52c6655 100644 --- a/mindone/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/mindone/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -185,7 +185,7 @@ def __init__( # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas self.alphas_cumprod = ops.cumprod(self.alphas, dim=0) @@ -203,13 +203,13 @@ def __init__( if algorithm_type == "deis": self.register_to_config(algorithm_type="dpmsolver++") else: - raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") + raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}") if solver_type not in ["midpoint", "heun"]: if solver_type in ["logrho", "bh1", "bh2"]: self.register_to_config(solver_type="midpoint") else: - raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") + raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") # setable values self.num_inference_steps = None diff --git a/mindone/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/mindone/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 86318c64fc..e2a8557e22 100644 --- a/mindone/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/mindone/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -109,11 +109,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `algorithm_type="dpmsolver++"`. algorithm_type (`str`, defaults to `dpmsolver++`): - Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++`. The - `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) - paper, and the `dpmsolver++` type implements the algorithms in the - [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or - `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion. + Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the + algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the `dpmsolver++` type + implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is + recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in + Stable Diffusion. solver_type (`str`, defaults to `midpoint`): Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. @@ -124,8 +124,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, the sigmas are determined according to a sequence of noise levels {σi}. final_sigmas_type (`str`, *optional*, defaults to `"zero"`): - The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final sigma - is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. lambda_min_clipped (`float`, defaults to `-inf`): Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the cosine (`squaredcos_cap_v2`) noise schedule. @@ -178,7 +178,7 @@ def __init__( # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas self.alphas_cumprod = ops.cumprod(self.alphas, dim=0) @@ -196,12 +196,12 @@ def __init__( if algorithm_type == "deis": self.register_to_config(algorithm_type="dpmsolver++") else: - raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") + raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}") if solver_type not in ["midpoint", "heun"]: if solver_type in ["logrho", "bh1", "bh2"]: self.register_to_config(solver_type="midpoint") else: - raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") + raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") if algorithm_type != "dpmsolver++" and final_sigmas_type == "zero": raise ValueError( @@ -279,24 +279,40 @@ def set_begin_index(self, begin_index: int = 0): """ self._begin_index = begin_index - def set_timesteps(self, num_inference_steps: int): + def set_timesteps(self, num_inference_steps: int = None, timesteps: Optional[List[int]] = None): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). Args: num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of equal spacing between timesteps schedule is used. If `timesteps` is + passed, `num_inference_steps` must be `None`. """ + if num_inference_steps is None and timesteps is None: + raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.") + if num_inference_steps is not None and timesteps is not None: + raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.") + if timesteps is not None and self.config.use_karras_sigmas: + raise ValueError("Cannot use `timesteps` when `config.use_karras_sigmas=True`.") + + num_inference_steps = num_inference_steps or len(timesteps) self.num_inference_steps = num_inference_steps - # Clipping the minimum of all lambda(t) for numerical stability. - # This is critical for cosine (squaredcos_cap_v2) noise schedule. - clipped_idx = np.searchsorted(ops.flip(self.lambda_t, [0]).asnumpy(), self.config.lambda_min_clipped) - timesteps = ( - np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1) - .round()[::-1][:-1] - .copy() - .astype(np.int64) - ) + + if timesteps is not None: + timesteps = np.array(timesteps).astype(np.int64) + else: + # Clipping the minimum of all lambda(t) for numerical stability. + # This is critical for cosine (squaredcos_cap_v2) noise schedule. + clipped_idx = np.searchsorted(ops.flip(self.lambda_t, [0]).asnumpy(), self.config.lambda_min_clipped) + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) sigmas = (((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5).asnumpy() if self.config.use_karras_sigmas: diff --git a/mindone/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py b/mindone/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py index 9059238d36..2792d35f89 100644 --- a/mindone/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +++ b/mindone/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py @@ -14,6 +14,7 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm +import math from typing import List, Optional, Tuple, Union import numpy as np @@ -46,6 +47,10 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): range is [0.2, 80.0]. sigma_data (`float`, *optional*, defaults to 0.5): The standard deviation of the data distribution. This is set to 0.5 in the EDM paper [1]. + sigma_schedule (`str`, *optional*, defaults to `karras`): + Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper + (https://arxiv.org/abs/2206.00364). Other acceptable value is "exponential". The exponential schedule was + incorporated in this model: https://huggingface.co/stabilityai/cosxl. num_train_timesteps (`int`, defaults to 1000): The number of diffusion steps to train the model. solver_order (`int`, defaults to 2): @@ -64,10 +69,9 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `algorithm_type="dpmsolver++"`. algorithm_type (`str`, defaults to `dpmsolver++`): - Algorithm type for the solver; can be `dpmsolver++` or `sde-dpmsolver++`. The - `dpmsolver++` type implements the algorithms in the - [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or - `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion. + Algorithm type for the solver; can be `dpmsolver++` or `sde-dpmsolver++`. The `dpmsolver++` type implements + the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to + use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion. solver_type (`str`, defaults to `midpoint`): Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. @@ -79,8 +83,8 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference steps, but sometimes may result in blurring. final_sigmas_type (`str`, defaults to `"zero"`): - The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final sigma - is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. """ _compatibles = [] @@ -92,6 +96,7 @@ def __init__( sigma_min: float = 0.002, sigma_max: float = 80.0, sigma_data: float = 0.5, + sigma_schedule: str = "karras", num_train_timesteps: int = 1000, prediction_type: str = "epsilon", rho: float = 7.0, @@ -116,7 +121,7 @@ def __init__( if solver_type in ["logrho", "bh1", "bh2"]: self.register_to_config(solver_type="midpoint") else: - raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") + raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero": raise ValueError( @@ -124,7 +129,11 @@ def __init__( ) ramp = ms.tensor(np.linspace(0, 1, num_train_timesteps), dtype=ms.float32) - sigmas = self._compute_sigmas(ramp) + if sigma_schedule == "karras": + sigmas = self._compute_karras_sigmas(ramp) + elif sigma_schedule == "exponential": + sigmas = self._compute_exponential_sigmas(ramp) + self.timesteps = self.precondition_noise(sigmas) self.sigmas = self.sigmas = ops.cat([sigmas, ops.zeros(1, dtype=sigmas.dtype)]) @@ -144,7 +153,7 @@ def init_noise_sigma(self): @property def step_index(self): """ - The index counter for current timestep. It will increae 1 after each scheduler step. + The index counter for current timestep. It will increase 1 after each scheduler step. """ return self._step_index @@ -233,10 +242,13 @@ def set_timesteps(self, num_inference_steps: int = None): self.num_inference_steps = num_inference_steps - ramp = np.linspace(0, 1, self.num_inference_steps) - sigmas = self._compute_sigmas(ramp) + ramp = ms.tensor(np.linspace(0, 1, self.num_inference_steps)) + if self.config.sigma_schedule == "karras": + sigmas = self._compute_karras_sigmas(ramp) + elif self.config.sigma_schedule == "exponential": + sigmas = self._compute_exponential_sigmas(ramp) - sigmas = ms.tensor(sigmas, dtype=ms.float32) + sigmas = sigmas.to(ms.float32) self.timesteps = self.precondition_noise(sigmas) if self.config.final_sigmas_type == "sigma_min": @@ -259,8 +271,8 @@ def set_timesteps(self, num_inference_steps: int = None): self._step_index = None self._begin_index = None - # Taken from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17 - def _compute_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> ms.Tensor: + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas + def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> ms.Tensor: """Constructs the noise schedule of Karras et al. (2022).""" sigma_min = sigma_min or self.config.sigma_min @@ -272,6 +284,17 @@ def _compute_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> ms.Tensor: sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return sigmas + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas + def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> ms.Tensor: + """Implementation closely follows k-diffusion. + + https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26 + """ + sigma_min = sigma_min or self.config.sigma_min + sigma_max = sigma_max or self.config.sigma_max + sigmas = ms.tensor(np.linspace(math.log(sigma_min), math.log(sigma_max), len(ramp))).exp().flip((0,)) + return sigmas + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: ms.Tensor) -> ms.Tensor: """ diff --git a/mindone/diffusers/schedulers/scheduling_edm_euler.py b/mindone/diffusers/schedulers/scheduling_edm_euler.py index 1848066b08..f359339bd1 100644 --- a/mindone/diffusers/schedulers/scheduling_edm_euler.py +++ b/mindone/diffusers/schedulers/scheduling_edm_euler.py @@ -208,13 +208,13 @@ def set_timesteps(self, num_inference_steps: int): """ self.num_inference_steps = num_inference_steps - ramp = np.linspace(0, 1, self.num_inference_steps) + ramp = ms.tensor(np.linspace(0, 1, self.num_inference_steps)) if self.config.sigma_schedule == "karras": sigmas = self._compute_karras_sigmas(ramp) elif self.config.sigma_schedule == "exponential": sigmas = self._compute_exponential_sigmas(ramp) - sigmas = ms.tensor(sigmas, dtype=ms.float32) + sigmas = sigmas.to(ms.float32) self.timesteps = self.precondition_noise(sigmas) self.sigmas = ops.cat([sigmas, ops.zeros(1)]) @@ -241,7 +241,7 @@ def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> m """ sigma_min = sigma_min or self.config.sigma_min sigma_max = sigma_max or self.config.sigma_max - sigmas = ms.tensor(np.linspace(math.log(sigma_min), math.log(sigma_max), len(ramp)).exp().flip(0)) + sigmas = ms.tensor(np.linspace(math.log(sigma_min), math.log(sigma_max), len(ramp)).exp().flip((0,))) return sigmas # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep diff --git a/mindone/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/mindone/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index 85491ccf5a..15f97154b0 100644 --- a/mindone/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/mindone/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -193,7 +193,7 @@ def __init__( # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") if rescale_betas_zero_snr: self.betas = rescale_zero_terminal_snr(self.betas) diff --git a/mindone/diffusers/schedulers/scheduling_euler_discrete.py b/mindone/diffusers/schedulers/scheduling_euler_discrete.py index c158f52701..6b9baa5b23 100644 --- a/mindone/diffusers/schedulers/scheduling_euler_discrete.py +++ b/mindone/diffusers/schedulers/scheduling_euler_discrete.py @@ -167,6 +167,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and dark samples instead of limiting it to samples with medium brightness. Loosely related to [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + final_sigmas_type (`str`, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -189,6 +192,7 @@ def __init__( timestep_type: str = "discrete", # can be "discrete" or "continuous" steps_offset: int = 0, rescale_betas_zero_snr: bool = False, + final_sigmas_type: str = "zero", # can be "zero" or "sigma_min" ): if trained_betas is not None: self.betas = ms.tensor(trained_betas, dtype=ms.float32) @@ -203,7 +207,7 @@ def __init__( # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") if rescale_betas_zero_snr: self.betas = rescale_zero_terminal_snr(self.betas) @@ -295,66 +299,125 @@ def scale_model_input(self, sample: ms.Tensor, timestep: Union[float, ms.Tensor] self.is_scale_input_called = True return sample - def set_timesteps(self, num_inference_steps: int): + def set_timesteps( + self, + num_inference_steps: int = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + ): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). Args: num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated + based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas` + must be `None`, and `timestep_spacing` attribute will be ignored. + sigmas (`List[float]`, *optional*): + Custom sigmas used to support arbitrary timesteps schedule schedule. If `None`, timesteps and sigmas + will be generated based on the relevant scheduler attributes. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`, and the timesteps will be generated based on the + custom sigmas schedule. """ - self.num_inference_steps = num_inference_steps - # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 - if self.config.timestep_spacing == "linspace": - timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[ - ::-1 - ].copy() - elif self.config.timestep_spacing == "leading": - step_ratio = self.config.num_train_timesteps // self.num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32) - timesteps += self.config.steps_offset - elif self.config.timestep_spacing == "trailing": - step_ratio = self.config.num_train_timesteps / self.num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32) - timesteps -= 1 - else: + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` should be set.") + if num_inference_steps is None and timesteps is None and sigmas is None: + raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps` or `sigmas.") + if num_inference_steps is not None and (timesteps is not None or sigmas is not None): + raise ValueError("Can only pass one of `num_inference_steps` or `timesteps` or `sigmas`.") + if timesteps is not None and self.config.use_karras_sigmas: + raise ValueError("Cannot set `timesteps` with `config.use_karras_sigmas = True`.") + if ( + timesteps is not None + and self.config.timestep_type == "continuous" + and self.config.prediction_type == "v_prediction" + ): raise ValueError( - f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + "Cannot set `timesteps` with `config.timestep_type = 'continuous'` and `config.prediction_type = 'v_prediction'`." ) - sigmas = (((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5).asnumpy() - log_sigmas = np.log(sigmas) + if num_inference_steps is None: + num_inference_steps = len(timesteps) if timesteps is not None else len(sigmas) - 1 + self.num_inference_steps = num_inference_steps - if self.config.interpolation_type == "linear": - sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - elif self.config.interpolation_type == "log_linear": - sigmas = np.exp(np.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1)) - else: - raise ValueError( - f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either" - " 'linear' or 'log_linear'" - ) + if sigmas is not None: + log_sigmas = np.log(np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)) + sigmas = np.array(sigmas).astype(np.float32) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas[:-1]]) - if self.use_karras_sigmas: - sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) - timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + else: + if timesteps is not None: + timesteps = np.array(timesteps).astype(np.float32) + else: + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace( + 0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32 + )[::-1].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = ( + (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32) + ) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + sigmas = (((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5).asnumpy() + log_sigmas = np.log(sigmas) + + if self.config.interpolation_type == "linear": + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + elif self.config.interpolation_type == "log_linear": + sigmas = np.exp(np.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1)) + else: + raise ValueError( + f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either" + " 'linear' or 'log_linear'" + ) + + if self.config.use_karras_sigmas: + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + sigma_last = ( + sigma_last.asnumpy() + ) # Transform for numpy concatenate where Torch tensor could be concated with numpy array directly + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) sigmas = ms.Tensor(sigmas).to(dtype=ms.float32) # TODO: Support the full EDM scalings for all prediction types and timestep types if self.config.timestep_type == "continuous" and self.config.prediction_type == "v_prediction": - self.timesteps = ms.Tensor([0.25 * sigma.log().item() for sigma in sigmas]) + self.timesteps = ms.Tensor([0.25 * sigma.log().item() for sigma in sigmas[:-1]]) else: self.timesteps = ms.Tensor(timesteps.astype(np.float32)) - self.sigmas = ops.cat([sigmas, ops.zeros(1)]) self._step_index = None self._begin_index = None + self.sigmas = sigmas def _sigma_to_t(self, sigma, log_sigmas): # get log sigma @@ -561,5 +624,32 @@ def add_noise( noisy_samples = original_samples + noise * sigma return noisy_samples + def get_velocity(self, sample: ms.Tensor, noise: ms.Tensor, timesteps: ms.Tensor) -> ms.Tensor: + if isinstance(timesteps, int) or (isinstance(timesteps, ms.Tensor) and timesteps.dtype in [ms.int32, ms.int64]): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.get_velocity()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + schedule_timesteps = self.timesteps + + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + alphas_cumprod = self.alphas_cumprod.to(sample.dtype) + sqrt_alpha_prod = alphas_cumprod[step_indices] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[step_indices]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + def __len__(self): return self.config.num_train_timesteps diff --git a/mindone/diffusers/schedulers/scheduling_heun_discrete.py b/mindone/diffusers/schedulers/scheduling_heun_discrete.py index 6993fddd29..5ecb443666 100644 --- a/mindone/diffusers/schedulers/scheduling_heun_discrete.py +++ b/mindone/diffusers/schedulers/scheduling_heun_discrete.py @@ -139,7 +139,7 @@ def __init__( elif beta_schedule == "exp": self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="exp") else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas self.alphas_cumprod = ops.cumprod(self.alphas, dim=0) @@ -230,8 +230,9 @@ def scale_model_input( def set_timesteps( self, - num_inference_steps: int, + num_inference_steps: Optional[int] = None, num_train_timesteps: Optional[int] = None, + timesteps: Optional[List[int]] = None, ): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -239,30 +240,47 @@ def set_timesteps( Args: num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. + num_train_timesteps (`int`, *optional*): + The number of diffusion steps used when training the model. If `None`, the default + `num_train_timesteps` attribute is used. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, timesteps will be + generated based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` + must be `None`, and `timestep_spacing` attribute will be ignored. """ + if num_inference_steps is None and timesteps is None: + raise ValueError("Must pass exactly one of `num_inference_steps` or `custom_timesteps`.") + if num_inference_steps is not None and timesteps is not None: + raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.") + if timesteps is not None and self.config.use_karras_sigmas: + raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`") + + num_inference_steps = num_inference_steps or len(timesteps) self.num_inference_steps = num_inference_steps - num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps - # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 - if self.config.timestep_spacing == "linspace": - timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy() - elif self.config.timestep_spacing == "leading": - step_ratio = num_train_timesteps // self.num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32) - timesteps += self.config.steps_offset - elif self.config.timestep_spacing == "trailing": - step_ratio = num_train_timesteps / self.num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32) - timesteps -= 1 + if timesteps is not None: + timesteps = np.array(timesteps, dtype=np.float32) else: - raise ValueError( - f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." - ) + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) sigmas = (((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5).asnumpy() log_sigmas = np.log(sigmas) diff --git a/mindone/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/mindone/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py index db73785160..5e7e4668e3 100644 --- a/mindone/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +++ b/mindone/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py @@ -133,7 +133,7 @@ def __init__( # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas self.alphas_cumprod = ops.cumprod(self.alphas, dim=0) diff --git a/mindone/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/mindone/diffusers/schedulers/scheduling_k_dpm_2_discrete.py index 831b96aaf1..b9ac134ad5 100644 --- a/mindone/diffusers/schedulers/scheduling_k_dpm_2_discrete.py +++ b/mindone/diffusers/schedulers/scheduling_k_dpm_2_discrete.py @@ -132,7 +132,7 @@ def __init__( # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas self.alphas_cumprod = ops.cumprod(self.alphas, dim=0) diff --git a/mindone/diffusers/schedulers/scheduling_lcm.py b/mindone/diffusers/schedulers/scheduling_lcm.py index 2673385a8a..a3b59d9675 100644 --- a/mindone/diffusers/schedulers/scheduling_lcm.py +++ b/mindone/diffusers/schedulers/scheduling_lcm.py @@ -227,7 +227,7 @@ def __init__( # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") # Rescale for zero SNR if rescale_betas_zero_snr: diff --git a/mindone/diffusers/schedulers/scheduling_lms_discrete.py b/mindone/diffusers/schedulers/scheduling_lms_discrete.py index 86eeb19653..59d9ef4cd5 100644 --- a/mindone/diffusers/schedulers/scheduling_lms_discrete.py +++ b/mindone/diffusers/schedulers/scheduling_lms_discrete.py @@ -153,7 +153,7 @@ def __init__( # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas self.alphas_cumprod = ops.cumprod(self.alphas, dim=0) @@ -325,7 +325,7 @@ def _init_step_index(self, timestep): else: self._step_index = self._begin_index - # copied from diffusers.schedulers.scheduling_euler_discrete._sigma_to_t + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t def _sigma_to_t(self, sigma, log_sigmas): # get log sigma log_sigma = np.log(np.maximum(sigma, 1e-10)) @@ -349,7 +349,7 @@ def _sigma_to_t(self, sigma, log_sigmas): t = t.reshape(sigma.shape) return t - # copied from diffusers.schedulers.scheduling_euler_discrete._convert_to_karras + # copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras def _convert_to_karras(self, in_sigmas: ms.Tensor) -> ms.Tensor: """Constructs the noise schedule of Karras et al. (2022).""" diff --git a/mindone/diffusers/schedulers/scheduling_pndm.py b/mindone/diffusers/schedulers/scheduling_pndm.py index a1095f850d..dc14600299 100644 --- a/mindone/diffusers/schedulers/scheduling_pndm.py +++ b/mindone/diffusers/schedulers/scheduling_pndm.py @@ -61,7 +61,7 @@ def alpha_bar_fn(t): return math.exp(t * -12.0) else: - raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") betas = [] for i in range(num_diffusion_timesteps): @@ -139,7 +139,7 @@ def __init__( # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas self.alphas_cumprod = ops.cumprod(self.alphas, dim=0) diff --git a/mindone/diffusers/schedulers/scheduling_repaint.py b/mindone/diffusers/schedulers/scheduling_repaint.py index 9ca383a3c9..3a0525dd47 100644 --- a/mindone/diffusers/schedulers/scheduling_repaint.py +++ b/mindone/diffusers/schedulers/scheduling_repaint.py @@ -147,7 +147,7 @@ def __init__( betas = ms.tensor(np.linspace(-6, 6, num_train_timesteps)) self.betas = ops.sigmoid(betas) * (beta_end - beta_start) + beta_start else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas self.alphas_cumprod = ops.cumprod(self.alphas, dim=0) diff --git a/mindone/diffusers/schedulers/scheduling_sasolver.py b/mindone/diffusers/schedulers/scheduling_sasolver.py index 49928073a3..78021dac0b 100644 --- a/mindone/diffusers/schedulers/scheduling_sasolver.py +++ b/mindone/diffusers/schedulers/scheduling_sasolver.py @@ -180,7 +180,7 @@ def __init__( # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas self.alphas_cumprod = ops.cumprod(self.alphas, dim=0) @@ -194,7 +194,7 @@ def __init__( self.init_noise_sigma = 1.0 if algorithm_type not in ["data_prediction", "noise_prediction"]: - raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") + raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}") # setable values self.num_inference_steps = None diff --git a/mindone/diffusers/schedulers/scheduling_tcd.py b/mindone/diffusers/schedulers/scheduling_tcd.py index 14e45c52ca..bd68988a87 100644 --- a/mindone/diffusers/schedulers/scheduling_tcd.py +++ b/mindone/diffusers/schedulers/scheduling_tcd.py @@ -228,7 +228,7 @@ def __init__( # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") # Rescale for zero SNR if rescale_betas_zero_snr: diff --git a/mindone/diffusers/schedulers/scheduling_unipc_multistep.py b/mindone/diffusers/schedulers/scheduling_unipc_multistep.py index 66bbc21f46..d5c87c9d35 100644 --- a/mindone/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/mindone/diffusers/schedulers/scheduling_unipc_multistep.py @@ -73,6 +73,43 @@ def alpha_bar_fn(t): return ms.tensor(betas, dtype=ms.float32) +# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`ms.Tensor`): + the betas that the scheduler is being initialized with. + + Returns: + `ms.Tensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = ops.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].copy() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].copy() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = ops.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): """ `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models. @@ -129,6 +166,13 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. steps_offset (`int`, defaults to 0): An offset added to the inference steps, as required by some model families. + final_sigmas_type (`str`, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -155,6 +199,8 @@ def __init__( use_karras_sigmas: Optional[bool] = False, timestep_spacing: str = "linspace", steps_offset: int = 0, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = ms.tensor(trained_betas, dtype=ms.float32) @@ -169,7 +215,15 @@ def __init__( # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + + if rescale_betas_zero_snr: + # Close to 0 without being 0 so first sigma is not inf + # FP16 smallest positive subnormal works well here + self.alphas_cumprod[-1] = 2**-24 self.alphas = 1.0 - self.betas self.alphas_cumprod = ops.cumprod(self.alphas, dim=0) @@ -186,7 +240,7 @@ def __init__( if solver_type in ["midpoint", "heun", "logrho"]: self.register_to_config(solver_type="bh2") else: - raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") + raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") self.predict_x0 = predict_x0 # setable values @@ -266,10 +320,28 @@ def set_timesteps(self, num_inference_steps: int): sigmas = np.flip(sigmas).copy() sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() - sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) + if self.config.final_sigmas_type == "sigma_min": + sigma_last = sigmas[-1] + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + sigma_last = ( + sigma_last.asnumpy() + ) # Transform for numpy concatenate where Torch tensor could be concated with numpy array directly + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) self.sigmas = ms.Tensor(sigmas) @@ -557,7 +629,7 @@ def multistep_uni_p_bh_update( if order == 2: rhos_p = ms.tensor([0.5], dtype=x.dtype) else: - rhos_p = ms.Tensor(np.linalg.solve(R[:-1, :-1].asnumpy(), b[:-1].asnumpy())) + rhos_p = ms.Tensor(np.linalg.solve(R[:-1, :-1].asnumpy(), b[:-1].asnumpy()), dtype=x.dtype) else: D1s = None @@ -694,7 +766,7 @@ def multistep_uni_c_bh_update( if order == 1: rhos_c = ms.tensor([0.5], dtype=x.dtype) else: - rhos_c = ms.Tensor(np.linalg.solve(R.asnumpy(), b.asnumpy())) + rhos_c = ms.Tensor(np.linalg.solve(R.asnumpy(), b.asnumpy()), dtype=x.dtype) if self.predict_x0: x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 diff --git a/mindone/diffusers/schedulers/scheduling_utils.py b/mindone/diffusers/schedulers/scheduling_utils.py index 124c69f1b6..3d73bec305 100644 --- a/mindone/diffusers/schedulers/scheduling_utils.py +++ b/mindone/diffusers/schedulers/scheduling_utils.py @@ -47,6 +47,15 @@ class KarrasDiffusionSchedulers(Enum): EDMEulerScheduler = 15 +AysSchedules = { + "StableDiffusionTimesteps": [999, 850, 736, 645, 545, 455, 343, 233, 124, 24], + "StableDiffusionSigmas": [14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.0], + "StableDiffusionXLTimesteps": [999, 845, 730, 587, 443, 310, 193, 116, 53, 13], + "StableDiffusionXLSigmas": [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.0], + "StableDiffusionVideoSigmas": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.0], +} + + @dataclass class SchedulerOutput(BaseOutput): """ @@ -112,9 +121,9 @@ def from_pretrained( force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to resume downloading the model weights and configuration files. If set to `False`, any - incompletely downloaded files are deleted. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1 + of Diffusers. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. diff --git a/mindone/diffusers/training_utils.py b/mindone/diffusers/training_utils.py index f6b792c097..7ee6118d50 100644 --- a/mindone/diffusers/training_utils.py +++ b/mindone/diffusers/training_utils.py @@ -5,7 +5,7 @@ import time from abc import ABCMeta, abstractmethod from multiprocessing import Process, Queue -from typing import Any, Dict, Iterable, List, Optional, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import numpy as np from tqdm.auto import tqdm @@ -17,6 +17,8 @@ from mindone.diffusers._peft import set_peft_model_state_dict +from .models import UNet2DConditionModel +from .schedulers import SchedulerMixin from .utils import convert_state_dict_to_diffusers, convert_state_dict_to_peft logger = logging.getLogger(__name__) @@ -48,6 +50,58 @@ def compute_snr(noise_scheduler, timesteps): return snr +def compute_dream_and_update_latents( + unet: UNet2DConditionModel, + noise_scheduler: SchedulerMixin, + timesteps: ms.Tensor, + noise: ms.Tensor, + noisy_latents: ms.Tensor, + target: ms.Tensor, + encoder_hidden_states: ms.Tensor, + dream_detail_preservation: float = 1.0, +) -> Tuple[Optional[ms.Tensor], Optional[ms.Tensor]]: + """ + Implements "DREAM (Diffusion Rectification and Estimation-Adaptive Models)" from http://arxiv.org/abs/2312.00210. + DREAM helps align training with sampling to help training be more efficient and accurate at the cost of an extra + forward step without gradients. + + Args: + `unet`: The state unet to use to make a prediction. + `noise_scheduler`: The noise scheduler used to add noise for the given timestep. + `timesteps`: The timesteps for the noise_scheduler to user. + `noise`: A tensor of noise in the shape of noisy_latents. + `noisy_latents`: Previously noise latents from the training loop. + `target`: The ground-truth tensor to predict after eps is removed. + `encoder_hidden_states`: Text embeddings from the text model. + `dream_detail_preservation`: A float value that indicates detail preservation level. + See reference. + + Returns: + `tuple[ms.Tensor, ms.Tensor]`: Adjusted noisy_latents and target. + """ + alphas_cumprod = noise_scheduler.alphas_cumprod[timesteps, None, None, None] + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # The paper uses lambda = sqrt(1 - alpha) ** p, with p = 1 in their experiments. + dream_lambda = sqrt_one_minus_alphas_cumprod**dream_detail_preservation + + pred = unet(noisy_latents, timesteps, encoder_hidden_states)[0] + + _noisy_latents, _target = (None, None) + if noise_scheduler.config["prediction_type"] == "epsilon": + predicted_noise = pred + delta_noise = ops.stop_gradient(noise - predicted_noise) + delta_noise = delta_noise * dream_lambda + _noisy_latents = noisy_latents.add(sqrt_one_minus_alphas_cumprod * delta_noise) + _target = target.add(delta_noise) + elif noise_scheduler.config["prediction_type"] == "v_prediction": + raise NotImplementedError("DREAM has not been implemented for v-prediction") + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config['prediction_type']}") + + return _noisy_latents, _target + + def cast_training_params(model: Union[nn.Cell, List[nn.Cell]], dtype=ms.float32): if not isinstance(model, list): model = [model] diff --git a/mindone/diffusers/utils/__init__.py b/mindone/diffusers/utils/__init__.py index 7f8cb17f43..cfcfef5555 100644 --- a/mindone/diffusers/utils/__init__.py +++ b/mindone/diffusers/utils/__init__.py @@ -21,20 +21,33 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, + SAFE_WEIGHTS_INDEX_NAME, SAFETENSORS_FILE_EXTENSION, SAFETENSORS_WEIGHTS_NAME, + WEIGHTS_INDEX_NAME, WEIGHTS_NAME, ) from .deprecation_utils import deprecate from .dynamic_modules_utils import get_class_from_dynamic_module from .export_utils import export_to_gif, export_to_obj, export_to_ply, export_to_video -from .hub_utils import PushToHubMixin, _add_variant, _get_model_file, extract_commit_hash, http_user_agent +from .hub_utils import ( + PushToHubMixin, + _add_variant, + _get_checkpoint_shard_files, + _get_model_file, + extract_commit_hash, + http_user_agent, +) from .import_utils import ( BACKENDS_MAPPING, _LazyModule, is_bs4_available, is_ftfy_available, + is_matplotlib_available, is_opencv_available, + is_peft_version, + is_scipy_available, + is_transformers_available, maybe_import_module_in_mindone, ) from .loading_utils import load_image diff --git a/mindone/diffusers/utils/constants.py b/mindone/diffusers/utils/constants.py index e616d7d10a..987a41bc25 100644 --- a/mindone/diffusers/utils/constants.py +++ b/mindone/diffusers/utils/constants.py @@ -17,9 +17,11 @@ CONFIG_NAME = "config.json" WEIGHTS_NAME = "diffusion_pytorch_model.bin" +WEIGHTS_INDEX_NAME = "diffusion_pytorch_model.bin.index.json" FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack" ONNX_WEIGHTS_NAME = "model.onnx" SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors" +SAFE_WEIGHTS_INDEX_NAME = "diffusion_pytorch_model.safetensors.index.json" SAFETENSORS_FILE_EXTENSION = "safetensors" ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb" HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co") diff --git a/mindone/diffusers/utils/dynamic_modules_utils.py b/mindone/diffusers/utils/dynamic_modules_utils.py index 0d191c01f2..31abb543c3 100644 --- a/mindone/diffusers/utils/dynamic_modules_utils.py +++ b/mindone/diffusers/utils/dynamic_modules_utils.py @@ -24,20 +24,18 @@ from typing import Dict, Optional, Union from urllib import request -from huggingface_hub import cached_download, hf_hub_download, model_info -from huggingface_hub.utils import validate_hf_hub_args +from huggingface_hub import hf_hub_download, model_info +from huggingface_hub.utils import RevisionNotFoundError, validate_hf_hub_args from packaging import version from .. import __version__ from . import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging -COMMUNITY_PIPELINES_URL = ( - "https://raw.githubusercontent.com/huggingface/diffusers/{revision}/examples/community/{pipeline}.py" -) - - logger = logging.get_logger(__name__) # pylint: disable=invalid-name +# See https://huggingface.co/datasets/diffusers/community-pipelines-mirror +COMMUNITY_PIPELINES_MIRROR_ID = "diffusers/community-pipelines-mirror" + def get_diffusers_versions(): url = "https://pypi.org/pypi/diffusers/json" @@ -199,7 +197,7 @@ def get_cached_module_file( module_file: str, cache_dir: Optional[Union[str, os.PathLike]] = None, force_download: bool = False, - resume_download: bool = False, + resume_download: Optional[bool] = None, proxies: Optional[Dict[str, str]] = None, token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, @@ -227,8 +225,9 @@ def get_cached_module_file( force_download (`bool`, *optional*, defaults to `False`): Whether or not to force to (re-)download the configuration files and override the cached versions if they exist. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1 + of Diffusers. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. @@ -244,8 +243,8 @@ def get_cached_module_file( - You may pass a token in `token` if you are not logged in (`huggingface-cli login`) and want to use private - or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models). + You may pass a token in `token` if you are not logged in (`huggingface-cli login`) and want to use private or + [gated models](https://huggingface.co/docs/hub/models-gated#gated-models). @@ -279,20 +278,24 @@ def get_cached_module_file( f" {', '.join(available_versions + ['main'])}." ) - # community pipeline on GitHub - github_url = COMMUNITY_PIPELINES_URL.format(revision=revision, pipeline=pretrained_model_name_or_path) try: - resolved_module_file = cached_download( - github_url, + resolved_module_file = hf_hub_download( + repo_id=COMMUNITY_PIPELINES_MIRROR_ID, + repo_type="dataset", + filename=f"{revision}/{pretrained_model_name_or_path}.py", cache_dir=cache_dir, force_download=force_download, proxies=proxies, - resume_download=resume_download, local_files_only=local_files_only, - token=False, ) submodule = "git" module_file = pretrained_model_name_or_path + ".py" + except RevisionNotFoundError as e: + raise EnvironmentError( + f"Revision '{revision}' not found in the community pipelines mirror. Check available revisions on" + " https://huggingface.co/datasets/diffusers/community-pipelines-mirror/tree/main." + " If you don't find the revision you are looking for, please open an issue on https://github.com/huggingface/diffusers/issues." + ) from e except EnvironmentError: logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.") raise @@ -327,6 +330,11 @@ def get_cached_module_file( # The only reason we do the copy is to avoid putting too many folders in sys.path. shutil.copy(resolved_module_file, submodule_path / module_file) for module_needed in modules_needed: + if len(module_needed.split(".")) == 2: + module_needed = "/".join(module_needed.split(".")) + module_folder = module_needed.split("/")[0] + if not os.path.exists(submodule_path / module_folder): + os.makedirs(submodule_path / module_folder) module_needed = f"{module_needed}.py" shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed) else: @@ -341,9 +349,16 @@ def get_cached_module_file( create_dynamic_module(full_submodule) if not (submodule_path / module_file).exists(): + if len(module_file.split("/")) == 2: + module_folder = module_file.split("/")[0] + if not os.path.exists(submodule_path / module_folder): + os.makedirs(submodule_path / module_folder) shutil.copy(resolved_module_file, submodule_path / module_file) + # Make sure we also have every file with relative for module_needed in modules_needed: + if len(module_needed.split(".")) == 2: + module_needed = "/".join(module_needed.split(".")) if not (submodule_path / module_needed).exists(): get_cached_module_file( pretrained_model_name_or_path, @@ -366,7 +381,7 @@ def get_class_from_dynamic_module( class_name: Optional[str] = None, cache_dir: Optional[Union[str, os.PathLike]] = None, force_download: bool = False, - resume_download: bool = False, + resume_download: Optional[bool] = None, proxies: Optional[Dict[str, str]] = None, token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, @@ -403,8 +418,9 @@ def get_class_from_dynamic_module( force_download (`bool`, *optional*, defaults to `False`): Whether or not to force to (re-)download the configuration files and override the cached versions if they exist. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1 of + Diffusers. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. @@ -420,8 +436,8 @@ def get_class_from_dynamic_module( - You may pass a token in `token` if you are not logged in (`huggingface-cli login`) and want to use private - or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models). + You may pass a token in `token` if you are not logged in (`huggingface-cli login`) and want to use private or + [gated models](https://huggingface.co/docs/hub/models-gated#gated-models). diff --git a/mindone/diffusers/utils/hub_utils.py b/mindone/diffusers/utils/hub_utils.py index 3e40057560..756cf27b56 100644 --- a/mindone/diffusers/utils/hub_utils.py +++ b/mindone/diffusers/utils/hub_utils.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import json import os import re import sys @@ -22,7 +23,15 @@ from typing import Dict, List, Optional, Union from uuid import uuid4 -from huggingface_hub import ModelCard, ModelCardData, create_repo, hf_hub_download, upload_folder +from huggingface_hub import ( + ModelCard, + ModelCardData, + create_repo, + hf_hub_download, + model_info, + snapshot_download, + upload_folder, +) from huggingface_hub.constants import HF_HUB_CACHE from huggingface_hub.file_download import REGEX_COMMIT_HASH from huggingface_hub.utils import ( @@ -72,7 +81,8 @@ def load_or_create_model_card( repo_id_or_path (`str`): The repo id (e.g., "runwayml/stable-diffusion-v1-5") or local path where to look for the model card. token (`str`, *optional*): - Authentication token. Will default to the stored token. See https://huggingface.co/settings/token for more details. + Authentication token. Will default to the stored token. See https://huggingface.co/settings/token for more + details. is_pipeline (`bool`): Boolean to indicate if we're adding tag to a [`DiffusionPipeline`]. from_training: (`bool`): Boolean flag to denote if the model card is being created from a training script. @@ -242,7 +252,7 @@ def _get_model_file( cache_dir: Optional[str] = None, force_download: bool = False, proxies: Optional[Dict] = None, - resume_download: bool = False, + resume_download: Optional[bool] = None, local_files_only: bool = False, token: Optional[str] = None, user_agent: Optional[Union[Dict, str]] = None, @@ -352,6 +362,109 @@ def _get_model_file( ) +# Adapted from +# https://github.com/huggingface/transformers/blob/1360801a69c0b169e3efdbb0cd05d9a0e72bfb70/src/transformers/utils/hub.py#L976 +# Differences are in parallelization of shard downloads and checking if shards are present. + + +def _check_if_shards_exist_locally(local_dir, subfolder, original_shard_filenames): + shards_path = os.path.join(local_dir, subfolder) + shard_filenames = [os.path.join(shards_path, f) for f in original_shard_filenames] + for shard_file in shard_filenames: + if not os.path.exists(shard_file): + raise ValueError( + f"{shards_path} does not appear to have a file named {shard_file} which is " + "required according to the checkpoint index." + ) + + +def _get_checkpoint_shard_files( + pretrained_model_name_or_path, + index_filename, + cache_dir=None, + proxies=None, + resume_download=False, + local_files_only=False, + token=None, + user_agent=None, + revision=None, + subfolder="", +): + """ + For a given model: + + - download and cache all the shards of a sharded checkpoint if `pretrained_model_name_or_path` is a model ID on the + Hub + - returns the list of paths to all the shards, as well as some metadata. + + For the description of each arg, see [`PreTrainedModel.from_pretrained`]. `index_filename` is the full path to the + index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub). + """ + if not os.path.isfile(index_filename): + raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.") + + with open(index_filename, "r") as f: + index = json.loads(f.read()) + + original_shard_filenames = sorted(set(index["weight_map"].values())) + sharded_metadata = index["metadata"] + sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys()) + sharded_metadata["weight_map"] = index["weight_map"].copy() + shards_path = os.path.join(pretrained_model_name_or_path, subfolder) + + # First, let's deal with local folder. + if os.path.isdir(pretrained_model_name_or_path): + _check_if_shards_exist_locally( + pretrained_model_name_or_path, subfolder=subfolder, original_shard_filenames=original_shard_filenames + ) + return pretrained_model_name_or_path, sharded_metadata + + # At this stage pretrained_model_name_or_path is a model identifier on the Hub + allow_patterns = original_shard_filenames + ignore_patterns = ["*.json", "*.md"] + if not local_files_only: + # `model_info` call must guarded with the above condition. + model_files_info = model_info(pretrained_model_name_or_path) + for shard_file in original_shard_filenames: + shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings) + if not shard_file_present: + raise EnvironmentError( + f"{shards_path} does not appear to have a file named {shard_file} which is " + "required according to the checkpoint index." + ) + + try: + # Load from URL + cached_folder = snapshot_download( + pretrained_model_name_or_path, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + user_agent=user_agent, + ) + + # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so + # we don't have to catch them here. We have also dealt with EntryNotFoundError. + except HTTPError as e: + raise EnvironmentError( + f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try" + " again after checking your internet connection." + ) from e + + # If `local_files_only=True`, `cached_folder` may not contain all the shard files. + if local_files_only: + _check_if_shards_exist_locally( + local_dir=cache_dir, subfolder=subfolder, original_shard_filenames=original_shard_filenames + ) + + return cached_folder, sharded_metadata + + class PushToHubMixin: """ A Mixin to push a model, scheduler, or pipeline to the Hugging Face Hub. diff --git a/mindone/diffusers/utils/import_utils.py b/mindone/diffusers/utils/import_utils.py index cebdb43d43..3d4d642324 100644 --- a/mindone/diffusers/utils/import_utils.py +++ b/mindone/diffusers/utils/import_utils.py @@ -15,14 +15,16 @@ Import utilities: Utilities related to imports and our lazy inits. """ import importlib.util +import operator as op import os import sys from collections import OrderedDict from itertools import chain from types import ModuleType -from typing import Any +from typing import Any, Union from huggingface_hub.utils import is_jinja_available # noqa: F401 +from packaging.version import Version, parse from . import logging @@ -35,6 +37,15 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt} + +_transformers_available = importlib.util.find_spec("transformers") is not None +try: + _transformers_version = importlib_metadata.version("transformers") + logger.debug(f"Successfully imported transformers version {_transformers_version}") +except importlib_metadata.PackageNotFoundError: + _transformers_available = False + # (sayakpaul): importlib.util.find_spec("opencv-python") returns None even when it's installed. # _opencv_available = importlib.util.find_spec("opencv-python") is not None try: @@ -58,6 +69,14 @@ _opencv_available = False +_scipy_available = importlib.util.find_spec("scipy") is not None +try: + _scipy_version = importlib_metadata.version("scipy") + logger.debug(f"Successfully imported scipy version {_scipy_version}") +except importlib_metadata.PackageNotFoundError: + _scipy_available = False + + _ftfy_available = importlib.util.find_spec("ftfy") is not None try: _ftfy_version = importlib_metadata.version("ftfy") @@ -75,10 +94,26 @@ _bs4_available = False +_matplotlib_available = importlib.util.find_spec("matplotlib") is not None +try: + _matplotlib_version = importlib_metadata.version("matplotlib") + logger.debug(f"Successfully imported matplotlib version {_matplotlib_version}") +except importlib_metadata.PackageNotFoundError: + _matplotlib_available = False + + +def is_transformers_available(): + return _transformers_available + + def is_opencv_available(): return _opencv_available +def is_scipy_available(): + return _scipy_available + + def is_ftfy_available(): return _ftfy_available @@ -87,12 +122,31 @@ def is_bs4_available(): return _bs4_available +def is_matplotlib_available(): + return _matplotlib_available + + # docstyle-ignore OPENCV_IMPORT_ERROR = """ {0} requires the OpenCV library but it was not found in your environment. You can install it with pip: `pip install opencv-python` """ + +# docstyle-ignore +SCIPY_IMPORT_ERROR = """ +{0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install +scipy` +""" + + +# docstyle-ignore +TRANSFORMERS_IMPORT_ERROR = """ +{0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip +install transformers` +""" + + # docstyle-ignore BS4_IMPORT_ERROR = """ {0} requires the Beautiful Soup library but it was not found in your environment. You can install it with pip: @@ -111,11 +165,49 @@ def is_bs4_available(): [ ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)), ("opencv", (is_opencv_available, OPENCV_IMPORT_ERROR)), + ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)), + ("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)), ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)), ] ) +# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319 +def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str): + """ + Args: + Compares a library version to some requirement using a given operation. + library_or_version (`str` or `packaging.version.Version`): + A library name or a version to check. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="`. + requirement_version (`str`): + The version to compare the library version against + """ + if operation not in STR_OPERATION_TO_FUNC.keys(): + raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}") + operation = STR_OPERATION_TO_FUNC[operation] + if isinstance(library_or_version, str): + library_or_version = parse(importlib_metadata.version(library_or_version)) + return operation(library_or_version, parse(requirement_version)) + + +def is_peft_version(operation: str, version: str): + """ + Args: + Compares the current PEFT version to a given reference with an operation. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + from mindone.diffusers._peft import __version__ as _peft_version + + if not _peft_version: + return False + return compare_versions(parse(_peft_version), operation, version) + + def maybe_import_module_in_mindone(module_name: str, force_original: bool = False): if force_original: return importlib.import_module(module_name) diff --git a/mindone/diffusers/utils/loading_utils.py b/mindone/diffusers/utils/loading_utils.py index 18f6ead64c..aa087e9817 100644 --- a/mindone/diffusers/utils/loading_utils.py +++ b/mindone/diffusers/utils/loading_utils.py @@ -16,8 +16,8 @@ def load_image( image (`str` or `PIL.Image.Image`): The image to convert to the PIL Image format. convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], optional): - A conversion method to apply to the image after loading it. - When set to `None` the image will be converted "RGB". + A conversion method to apply to the image after loading it. When set to `None` the image will be converted + "RGB". Returns: `PIL.Image.Image`: diff --git a/mindone/diffusers/utils/logging.py b/mindone/diffusers/utils/logging.py index 782c75e6c1..236fb1c2f9 100644 --- a/mindone/diffusers/utils/logging.py +++ b/mindone/diffusers/utils/logging.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" Logging utilities.""" +"""Logging utilities.""" import logging import os import sys @@ -71,7 +71,9 @@ def _configure_library_root_logger() -> None: # This library has already configured the library root logger. return _default_handler = logging.StreamHandler() # Set sys.stderr as stream. - _default_handler.flush = sys.stderr.flush + + if sys.stderr: # only if sys.stderr exists, e.g. when not using pythonw in windows + _default_handler.flush = sys.stderr.flush # Apply our default configuration to the library root logger. library_root_logger = _get_library_root_logger() diff --git a/mindone/diffusers/utils/mindspore_utils.py b/mindone/diffusers/utils/mindspore_utils.py index 3277c0f237..2e192261b2 100644 --- a/mindone/diffusers/utils/mindspore_utils.py +++ b/mindone/diffusers/utils/mindspore_utils.py @@ -14,6 +14,7 @@ """ MindSpore utilities: Utilities related to MindSpore """ + from typing import List, Optional, Tuple, Union import numpy as np diff --git a/mindone/diffusers/utils/peft_utils.py b/mindone/diffusers/utils/peft_utils.py index a53fdbfdc3..def4044e35 100644 --- a/mindone/diffusers/utils/peft_utils.py +++ b/mindone/diffusers/utils/peft_utils.py @@ -14,6 +14,7 @@ """ PEFT utilities: Utilities related to peft library """ + import collections from typing import Optional @@ -47,6 +48,9 @@ def scale_lora_layers(model, weight): """ from mindone.diffusers._peft.tuners.tuners_utils import BaseTunerLayer + if weight == 1.0: + return + for _, module in model.cells_and_names(): if isinstance(module, BaseTunerLayer): module.scale_layer(weight) @@ -66,6 +70,9 @@ def unscale_lora_layers(model, weight: Optional[float] = None): """ from mindone.diffusers._peft.tuners.tuners_utils import BaseTunerLayer + if weight == 1.0: + return + for _, module in model.cells_and_names(): if isinstance(module, BaseTunerLayer): if weight is not None and weight != 0: @@ -108,6 +115,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True # layer names without the Diffusers specific target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()}) + use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict) lora_config_kwargs = { "r": r, @@ -115,6 +123,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True "rank_pattern": rank_pattern, "alpha_pattern": alpha_pattern, "target_modules": target_modules, + "use_dora": use_dora, } return lora_config_kwargs @@ -163,16 +172,32 @@ def delete_adapter_layers(model, adapter_name): def set_weights_and_activate_adapters(model, adapter_names, weights): from mindone.diffusers._peft.tuners.tuners_utils import BaseTunerLayer + def get_module_weight(weight_for_adapter, module_name): + if not isinstance(weight_for_adapter, dict): + # If weight_for_adapter is a single number, always return it. + return weight_for_adapter + + for layer_name, weight_ in weight_for_adapter.items(): + if layer_name in module_name: + return weight_ + + parts = module_name.split(".") + # e.g. key = "down_blocks.1.attentions.0" + key = f"{parts[0]}.{parts[1]}.attentions.{parts[3]}" + block_weight = weight_for_adapter.get(key, 1.0) + + return block_weight + # iterate over each adapter, make it active and set the corresponding scaling weight for adapter_name, weight in zip(adapter_names, weights): - for _, module in model.cells_and_names(): + for module_name, module in model.cells_and_names(): if isinstance(module, BaseTunerLayer): # For backward compatibility with previous PEFT versions if hasattr(module, "set_adapter"): module.set_adapter(adapter_name) else: raise RuntimeError("'BaseTunerLayer' object has no attribute 'set_adapter'") - module.set_scale(adapter_name, weight) + module.set_scale(adapter_name, get_module_weight(weight, module_name)) # set multiple active adapters for _, module in model.cells_and_names(): diff --git a/mindone/diffusers/utils/state_dict_utils.py b/mindone/diffusers/utils/state_dict_utils.py index 0d96922a0b..dc07b09a31 100644 --- a/mindone/diffusers/utils/state_dict_utils.py +++ b/mindone/diffusers/utils/state_dict_utils.py @@ -14,6 +14,7 @@ """ State dict utilities: utility methods for converting state dicts easily """ + import enum from .logging import get_logger @@ -45,6 +46,7 @@ class StateDictType(enum.Enum): ".to_v_lora.up": ".to_v.lora_B", ".lora.up": ".lora_B", ".lora.down": ".lora_A", + ".to_out.lora_magnitude_vector": ".to_out.0.lora_magnitude_vector", } @@ -59,6 +61,8 @@ class StateDictType(enum.Enum): ".out_proj.lora_linear_layer.down": ".out_proj.lora_A", ".lora_linear_layer.up": ".lora_B", ".lora_linear_layer.down": ".lora_A", + "text_projection.lora.down.weight": "text_projection.lora_A.weight", + "text_projection.lora.up.weight": "text_projection.lora_B.weight", } DIFFUSERS_OLD_TO_PEFT = { @@ -102,6 +106,10 @@ class StateDictType(enum.Enum): ".to_v_lora.down": ".v_proj.lora_linear_layer.down", ".to_out_lora.up": ".out_proj.lora_linear_layer.up", ".to_out_lora.down": ".out_proj.lora_linear_layer.down", + ".to_k.lora_magnitude_vector": ".k_proj.lora_magnitude_vector", + ".to_v.lora_magnitude_vector": ".v_proj.lora_magnitude_vector", + ".to_q.lora_magnitude_vector": ".q_proj.lora_magnitude_vector", + ".to_out.lora_magnitude_vector": ".out_proj.lora_magnitude_vector", } PEFT_TO_KOHYA_SS = { @@ -246,8 +254,8 @@ def convert_unet_state_dict_to_peft(state_dict): def convert_all_state_dict_to_peft(state_dict): r""" - Attempts to first `convert_state_dict_to_peft`, and if it doesn't detect `lora_linear_layer` - for a valid `DIFFUSERS` LoRA for example, attempts to exclusively convert the Unet `convert_unet_state_dict_to_peft` + Attempts to first `convert_state_dict_to_peft`, and if it doesn't detect `lora_linear_layer` for a valid + `DIFFUSERS` LoRA for example, attempts to exclusively convert the Unet `convert_unet_state_dict_to_peft` """ try: peft_dict = convert_state_dict_to_peft(state_dict) diff --git a/mindone/diffusers/video_processor.py b/mindone/diffusers/video_processor.py new file mode 100644 index 0000000000..f5ce203435 --- /dev/null +++ b/mindone/diffusers/video_processor.py @@ -0,0 +1,115 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import List, Optional, Union + +import numpy as np +import PIL + +import mindspore as ms +from mindspore import ops + +from .image_processor import VaeImageProcessor, is_valid_image, is_valid_image_imagelist + + +class VideoProcessor(VaeImageProcessor): + r"""Simple video processor.""" + + def preprocess_video(self, video, height: Optional[int] = None, width: Optional[int] = None) -> ms.Tensor: + r""" + Preprocesses input video(s). + + Args: + video (`List[PIL.Image]`, `List[List[PIL.Image]]`, `ms.Tensor`, `np.array`, `List[ms.Tensor]`, `List[np.array]`): + The input video. It can be one of the following: + * List of the PIL images. + * List of list of PIL images. + * 4D MindSpore tensors (expected shape for each tensor `(num_frames, num_channels, height, width)`). + * 4D NumPy arrays (expected shape for each array `(num_frames, height, width, num_channels)`). + * List of 4D MindSpore tensors (expected shape for each tensor `(num_frames, num_channels, height, + width)`). + * List of 4D NumPy arrays (expected shape for each array `(num_frames, height, width, num_channels)`). + * 5D NumPy arrays: expected shape for each array `(batch_size, num_frames, height, width, + num_channels)`. + * 5D MindSpore tensors: expected shape for each array `(batch_size, num_frames, num_channels, height, + width)`. + height (`int`, *optional*, defaults to `None`): + The height in preprocessed frames of the video. If `None`, will use the `get_default_height_width()` to + get default height. + width (`int`, *optional*`, defaults to `None`): + The width in preprocessed frames of the video. If `None`, will use get_default_height_width()` to get + the default width. + """ + if isinstance(video, list) and isinstance(video[0], np.ndarray) and video[0].ndim == 5: + warnings.warn( + "Passing `video` as a list of 5d np.ndarray is deprecated." + "Please concatenate the list along the batch dimension and pass it as a single 5d np.ndarray", + FutureWarning, + ) + video = np.concatenate(video, axis=0) + if isinstance(video, list) and isinstance(video[0], ms.Tensor) and video[0].ndim == 5: + warnings.warn( + "Passing `video` as a list of 5d ms.Tensor is deprecated." + "Please concatenate the list along the batch dimension and pass it as a single 5d ms.Tensor", + FutureWarning, + ) + video = ops.cat(video, axis=0) + + # ensure the input is a list of videos: + # - if it is a batch of videos (5d ms.Tensor or np.ndarray), it is converted to a list of videos (a list of 4d ms.Tensor or np.ndarray) + # - if it is is a single video, it is convereted to a list of one video. + if isinstance(video, (np.ndarray, ms.Tensor)) and video.ndim == 5: + video = list(video) + elif isinstance(video, list) and is_valid_image(video[0]) or is_valid_image_imagelist(video): + video = [video] + elif isinstance(video, list) and is_valid_image_imagelist(video[0]): + video = video + else: + raise ValueError( + "Input is in incorrect format. Currently, we only support numpy.ndarray, ms.Tensor, PIL.Image.Image" + ) + + video = ops.stack([self.preprocess(img, height=height, width=width) for img in video], axis=0) + + # move the number of channels before the number of frames. + video = video.permute(0, 2, 1, 3, 4) + + return video + + def postprocess_video( + self, video: ms.Tensor, output_type: str = "np" + ) -> Union[np.ndarray, ms.Tensor, List[PIL.Image.Image]]: + r""" + Converts a video tensor to a list of frames for export. + + Args: + video (`ms.Tensor`): The video as a tensor. + output_type (`str`, defaults to `"np"`): Output type of the postprocessed `video` tensor. + """ + batch_size = video.shape[0] + outputs = [] + for batch_idx in range(batch_size): + batch_vid = video[batch_idx].permute(1, 0, 2, 3) + batch_output = self.postprocess(batch_vid, output_type) + outputs.append(batch_output) + + if output_type == "np": + outputs = np.stack(outputs) + elif output_type == "pt": + outputs = ops.stack(outputs) + elif not output_type == "pil": + raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']") + + return outputs diff --git a/tests/diffusers/models/modeling_test_utils.py b/tests/diffusers/models/modeling_test_utils.py index 10f27f1690..0cc34bc654 100644 --- a/tests/diffusers/models/modeling_test_utils.py +++ b/tests/diffusers/models/modeling_test_utils.py @@ -30,7 +30,7 @@ def get_pt2ms_mappings(m): for name, cell in m.cells_and_names(): if isinstance(cell, (nn.Conv1d, nn.Conv1dTranspose)): mappings[f"{name}.weight"] = f"{name}.weight", lambda x: ms.Parameter( - ops.expand_dims(x, axis=-2), name=x.name + ops.expand_dims(x, axis=-2), name=f"{name}.weight" ) elif isinstance(cell, nn.Embedding): mappings[f"{name}.weight"] = f"{name}.embedding_table", lambda x: x diff --git a/tests/diffusers/models/modules_test_cases.py b/tests/diffusers/models/modules_test_cases.py index f2b175c037..946063ecc9 100644 --- a/tests/diffusers/models/modules_test_cases.py +++ b/tests/diffusers/models/modules_test_cases.py @@ -449,27 +449,6 @@ # autoencoders -# VQModel: volatile in fp16(fyi: 2%-20% diff when torch.fp16 vs torch.fp32) -VQ_CASES = [ - [ - "VQModel", # volatile with random init: 2%-20% diff when torch.float16 v.s. torch.float32 - "diffusers.models.vq_model.VQModel", - "mindone.diffusers.models.vq_model.VQModel", - (), - { - "block_out_channels": [32, 64], - "in_channels": 3, - "out_channels": 3, - "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"], - "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"], - "latent_channels": 3, - }, - (), - {"sample": np.random.randn(4, 3, 32, 32).astype(np.float32), "return_dict": False}, - ], -] - - VAE_CASES = [ [ "AutoencoderKL", @@ -557,12 +536,25 @@ "return_dict": False, }, ], + [ + "VQModel", # volatile with random init: 2%-20% diff when torch.float16 v.s. torch.float32 + "diffusers.models.autoencoders.vq_model.VQModel", + "mindone.diffusers.models.autoencoders.vq_model.VQModel", + (), + { + "block_out_channels": [32, 64], + "in_channels": 3, + "out_channels": 3, + "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"], + "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"], + "latent_channels": 3, + }, + (), + {"sample": np.random.randn(4, 3, 32, 32).astype(np.float32), "return_dict": False}, + ], ] -AE_CASES = VQ_CASES + VAE_CASES - - # transformers TRANSFORMER2D_CASES = [ [ @@ -653,7 +645,108 @@ ] -TRANSFORMERS_CASES = TRANSFORMER2D_CASES + PRIOR_TRANSFORMER_CASES +DIT_TRANSFORMER2D_CASES = [ + [ + "DiTTransformer2DModel", + "diffusers.models.transformers.dit_transformer_2d.DiTTransformer2DModel", + "mindone.diffusers.models.transformers.dit_transformer_2d.DiTTransformer2DModel", + (), + { + "in_channels": 4, + "out_channels": 8, + "activation_fn": "gelu-approximate", + "num_attention_heads": 2, + "attention_head_dim": 4, + "attention_bias": True, + "num_layers": 1, + "norm_type": "ada_norm_zero", + "num_embeds_ada_norm": 8, + "patch_size": 2, + "sample_size": 8, + }, + (), + dict( + hidden_states=np.random.randn(4, 4, 8, 8), + timestep=np.random.randint(0, 1000, size=(4,)), + class_labels=np.random.randint(0, 4, size=(4,)), + return_dict=False, + ), + ], +] + + +PIXART_TRANSFORMER2D_CASES = [ + [ + "PixArtTransformer2DModel", + "diffusers.models.transformers.pixart_transformer_2d.PixArtTransformer2DModel", + "mindone.diffusers.models.transformers.pixart_transformer_2d.PixArtTransformer2DModel", + (), + { + "sample_size": 8, + "num_layers": 1, + "patch_size": 2, + "attention_head_dim": 2, + "num_attention_heads": 2, + "in_channels": 4, + "cross_attention_dim": 8, + "out_channels": 8, + "attention_bias": True, + "activation_fn": "gelu-approximate", + "num_embeds_ada_norm": 8, + "norm_type": "ada_norm_single", + "norm_elementwise_affine": False, + "norm_eps": 1e-6, + "use_additional_conditions": False, + "caption_channels": None, + }, + (), + { + "hidden_states": np.random.randn(4, 4, 8, 8), + "timestep": np.random.randint(0, 1000, size=(4,)), + "encoder_hidden_states": np.random.randn(4, 8, 8), + "added_cond_kwargs": {"aspect_ratio": None, "resolution": None}, + "return_dict": False, + }, + ], +] + + +SD3_TRANSFORMER2D_CASES = [ + [ + "SD3Transformer2DModel", + "diffusers.models.transformers.transformer_sd3.SD3Transformer2DModel", + "mindone.diffusers.models.transformers.transformer_sd3.SD3Transformer2DModel", + (), + { + "sample_size": 32, + "patch_size": 1, + "in_channels": 4, + "num_layers": 1, + "attention_head_dim": 8, + "num_attention_heads": 4, + "caption_projection_dim": 32, + "joint_attention_dim": 32, + "pooled_projection_dim": 64, + "out_channels": 4, + }, + (), + { + "hidden_states": np.random.randn(2, 4, 32, 32), + "encoder_hidden_states": np.random.randn(2, 154, 32), + "pooled_projections": np.random.randn(2, 64), + "timestep": np.random.randint(0, 1000, size=(2,)), + }, + ], +] + + +TRANSFORMERS_CASES = ( + DIT_TRANSFORMER2D_CASES + + PIXART_TRANSFORMER2D_CASES + + PRIOR_TRANSFORMER_CASES + + SD3_TRANSFORMER2D_CASES + + TRANSFORMER2D_CASES +) # unet @@ -715,6 +808,62 @@ ] +UNET2D_CASES = [ + [ + "UNet2DModel", + "diffusers.models.unets.unet_2d.UNet2DModel", + "mindone.diffusers.models.unets.unet_2d.UNet2DModel", + (), + { + "block_out_channels": (4, 8), + "norm_num_groups": 2, + "down_block_types": ("DownBlock2D", "AttnDownBlock2D"), + "up_block_types": ("AttnUpBlock2D", "UpBlock2D"), + "attention_head_dim": 3, + "out_channels": 3, + "in_channels": 3, + "layers_per_block": 2, + "sample_size": 32, + }, + (), + { + "sample": np.random.randn(4, 3, 32, 32), + "timestep": np.array([10]).astype(np.int32), + "return_dict": False, + }, + ], +] + + +UNET2D_CONDITION_CASES = [ + [ + "UNet2DConditionModel", + "diffusers.models.unets.unet_2d_condition.UNet2DConditionModel", + "mindone.diffusers.models.unets.unet_2d_condition.UNet2DConditionModel", + (), + { + "block_out_channels": (4, 8), + "norm_num_groups": 4, + "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"), + "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"), + "cross_attention_dim": 8, + "attention_head_dim": 2, + "out_channels": 4, + "in_channels": 4, + "layers_per_block": 1, + "sample_size": 16, + }, + (), + { + "sample": np.random.randn(4, 4, 16, 16), + "timestep": np.array([10]).astype(np.int32), + "encoder_hidden_states": np.random.randn(4, 4, 8), + "return_dict": False, + }, + ], +] + + UVIT2D_CASES = [ [ "UVit2DModel", @@ -926,23 +1075,23 @@ "mindone.diffusers.models.unets.unet_stable_cascade.StableCascadeUNet", (), dict( - block_out_channels=(2048, 2048), + block_out_channels=(96, 96), block_types_per_layer=( ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"), ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"), ), - clip_image_in_channels=768, + clip_image_in_channels=32, clip_seq=4, - clip_text_in_channels=1280, - clip_text_pooled_in_channels=1280, - conditioning_dim=2048, + clip_text_in_channels=64, + clip_text_pooled_in_channels=64, + conditioning_dim=96, down_blocks_repeat_mappers=(1, 1), - down_num_layers_per_block=(8, 24), + down_num_layers_per_block=(2, 2), dropout=(0.1, 0.1), effnet_in_channels=None, in_channels=16, kernel_size=3, - num_attention_heads=(32, 32), + num_attention_heads=(16, 16), out_channels=16, patch_size=1, pixel_mapper_in_channels=None, @@ -951,15 +1100,15 @@ timestep_conditioning_type=("sca", "crp"), timestep_ratio_embedding_dim=64, up_blocks_repeat_mappers=(1, 1), - up_num_layers_per_block=(24, 8), + up_num_layers_per_block=(2, 2), ), (), { "sample": np.random.randn(1, 16, 24, 24).astype(np.float32), "timestep_ratio": np.array([1], dtype=np.float32), - "clip_text_pooled": np.random.randn(1, 1, 1280).astype(np.float32), - "clip_text": np.random.randn(1, 77, 1280).astype(np.float32), - "clip_img": np.random.randn(1, 1, 768).astype(np.float32), + "clip_text_pooled": np.random.randn(1, 1, 64).astype(np.float32), + "clip_text": np.random.randn(1, 77, 64).astype(np.float32), + "clip_img": np.random.randn(1, 1, 32).astype(np.float32), "pixels": np.random.randn(1, 3, 8, 8).astype(np.float32), "return_dict": False, }, @@ -970,7 +1119,7 @@ "mindone.diffusers.models.unets.unet_stable_cascade.StableCascadeUNet", (), dict( - block_out_channels=(320, 640, 1280, 1280), + block_out_channels=(8, 16, 32, 32), block_types_per_layer=( ("SDCascadeResBlock", "SDCascadeTimestepBlock"), ("SDCascadeResBlock", "SDCascadeTimestepBlock"), @@ -980,10 +1129,10 @@ clip_image_in_channels=None, clip_seq=4, clip_text_in_channels=None, - clip_text_pooled_in_channels=1280, - conditioning_dim=1280, + clip_text_pooled_in_channels=32, + conditioning_dim=32, down_blocks_repeat_mappers=(1, 1, 1, 1), - down_num_layers_per_block=(2, 6, 28, 6), + down_num_layers_per_block=(1, 1, 1, 1), dropout=(0, 0, 0.1, 0.1), effnet_in_channels=16, in_channels=4, @@ -997,14 +1146,14 @@ timestep_conditioning_type=("sca",), timestep_ratio_embedding_dim=64, up_blocks_repeat_mappers=(3, 3, 2, 2), - up_num_layers_per_block=(6, 28, 6, 2), + up_num_layers_per_block=(1, 1, 1, 1), ), (), { - "sample": np.random.randn(1, 4, 256, 256).astype(np.float32), + "sample": np.random.randn(1, 4, 16, 16).astype(np.float32), "timestep_ratio": np.array([1], dtype=np.float32), - "clip_text_pooled": np.random.randn(1, 1, 1280).astype(np.float32), - "clip_text": np.random.randn(1, 77, 1280).astype(np.float32), + "clip_text_pooled": np.random.randn(1, 1, 32).astype(np.float32), + "clip_text": np.random.randn(1, 77, 32).astype(np.float32), "pixels": np.random.randn(1, 3, 8, 8).astype(np.float32), "return_dict": False, }, @@ -1012,8 +1161,44 @@ ] +UNET_CONTROLNET_XS_CASES = [ + [ + "UNetControlNetXSModel", + "diffusers.models.UNetControlNetXSModel", + "mindone.diffusers.models.UNetControlNetXSModel", + (), + { + "sample_size": 16, + "down_block_types": ("DownBlock2D", "CrossAttnDownBlock2D"), + "up_block_types": ("CrossAttnUpBlock2D", "UpBlock2D"), + "block_out_channels": (4, 8), + "cross_attention_dim": 8, + "transformer_layers_per_block": 1, + "num_attention_heads": 2, + "norm_num_groups": 4, + "upcast_attention": False, + "ctrl_block_out_channels": [2, 4], + "ctrl_num_attention_heads": 4, + "ctrl_max_norm_num_groups": 2, + "ctrl_conditioning_embedding_out_channels": (2, 2), + }, + (), + { + "sample": np.random.randn(4, 4, 16, 16), + "timestep": np.array([10]).astype(np.int64), + "encoder_hidden_states": np.random.randn(4, 4, 8), + "controlnet_cond": np.random.randn(4, 3, 32, 32), + "conditioning_scale": 1, + "return_dict": False, + }, + ], +] + + UNETS_CASES = ( UNET1D_CASES + + UNET2D_CASES + + UNET2D_CONDITION_CASES + UVIT2D_CASES + KANDINSKY3_CASES + UNET3D_CONDITION_MODEL_CASES @@ -1021,8 +1206,9 @@ + UNET_I2VGEN_XL_CASES + UNET_MOTION_MODEL_TEST + UNETSTABLECASCADE_CASES + + UNET_CONTROLNET_XS_CASES ) # all -ALL_CASES = LAYERS_CASES + AE_CASES + TRANSFORMERS_CASES + UNETS_CASES +ALL_CASES = LAYERS_CASES + VAE_CASES + TRANSFORMERS_CASES + UNETS_CASES diff --git a/tests/diffusers/models/test_generic_modules.py b/tests/diffusers/models/test_generic_modules.py index d7d7c1907f..dfcacd2ed4 100644 --- a/tests/diffusers/models/test_generic_modules.py +++ b/tests/diffusers/models/test_generic_modules.py @@ -36,10 +36,10 @@ @pytest.mark.parametrize( - "name,pt_module,ms_module,init_args,init_kwargs,inputs_args,inputs_kwargs", - ALL_CASES, + "name,pt_module,ms_module,init_args,init_kwargs,inputs_args,inputs_kwargs,dtype,mode", + [case + context for case in ALL_CASES for context in [["fp16", 0], ["fp16", 1], ["fp32", 0], ["fp32", 1]]], ) -def test_named_modules_with_graph_fp32( +def test_named_modules( name, pt_module, ms_module, @@ -47,124 +47,10 @@ def test_named_modules_with_graph_fp32( init_kwargs, inputs_args, inputs_kwargs, + dtype, + mode, ): - dtype = "fp32" - ms.set_context(mode=ms.GRAPH_MODE, jit_syntax_level=ms.STRICT) - - ( - pt_model, - ms_model, - pt_dtype, - ms_dtype, - ) = get_modules(pt_module, ms_module, dtype, *init_args, **init_kwargs) - pt_inputs_args, pt_inputs_kwargs, ms_inputs_args, ms_inputs_kwargs = generalized_parse_args( - pt_dtype, ms_dtype, *inputs_args, **inputs_kwargs - ) - - with torch.no_grad(): - pt_outputs = pt_model(*pt_inputs_args, **pt_inputs_kwargs) - ms_outputs = ms_model(*ms_inputs_args, **ms_inputs_kwargs) - - diffs = compute_diffs(pt_outputs, ms_outputs) - - assert ( - np.array(diffs) < THRESHOLD_FP32 - ).all(), f"Outputs({np.array(diffs).tolist()}) has diff bigger than {THRESHOLD_FP32}" - - -@pytest.mark.parametrize( - "name,pt_module,ms_module,init_args,init_kwargs,inputs_args,inputs_kwargs", - ALL_CASES, -) -def test_named_modules_with_graph_fp16( - name, - pt_module, - ms_module, - init_args, - init_kwargs, - inputs_args, - inputs_kwargs, -): - dtype = "fp16" - ms.set_context(mode=ms.GRAPH_MODE, jit_syntax_level=ms.STRICT) - - ( - pt_model, - ms_model, - pt_dtype, - ms_dtype, - ) = get_modules(pt_module, ms_module, dtype, *init_args, **init_kwargs) - pt_inputs_args, pt_inputs_kwargs, ms_inputs_args, ms_inputs_kwargs = generalized_parse_args( - pt_dtype, ms_dtype, *inputs_args, **inputs_kwargs - ) - - if "hidden_dtype" in inspect.signature(pt_model.forward).parameters: - pt_inputs_kwargs.update({"hidden_dtype": PT_DTYPE_MAPPING[pt_dtype]}) - ms_inputs_kwargs.update({"hidden_dtype": MS_DTYPE_MAPPING[ms_dtype]}) - - with torch.no_grad(): - pt_outputs = pt_model(*pt_inputs_args, **pt_inputs_kwargs) - ms_outputs = ms_model(*ms_inputs_args, **ms_inputs_kwargs) - - diffs = compute_diffs(pt_outputs, ms_outputs) - - assert ( - np.array(diffs) < THRESHOLD_FP16 - ).all(), f"Outputs({np.array(diffs).tolist()}) has diff bigger than {THRESHOLD_FP16}" - - -@pytest.mark.parametrize( - "name,pt_module,ms_module,init_args,init_kwargs,inputs_args,inputs_kwargs", - ALL_CASES, -) -def test_named_modules_with_pynative_fp32( - name, - pt_module, - ms_module, - init_args, - init_kwargs, - inputs_args, - inputs_kwargs, -): - dtype = "fp32" - ms.set_context(mode=ms.PYNATIVE_MODE, jit_syntax_level=ms.STRICT) - - ( - pt_model, - ms_model, - pt_dtype, - ms_dtype, - ) = get_modules(pt_module, ms_module, dtype, *init_args, **init_kwargs) - pt_inputs_args, pt_inputs_kwargs, ms_inputs_args, ms_inputs_kwargs = generalized_parse_args( - pt_dtype, ms_dtype, *inputs_args, **inputs_kwargs - ) - - with torch.no_grad(): - pt_outputs = pt_model(*pt_inputs_args, **pt_inputs_kwargs) - ms_outputs = ms_model(*ms_inputs_args, **ms_inputs_kwargs) - - diffs = compute_diffs(pt_outputs, ms_outputs) - - assert ( - np.array(diffs) < THRESHOLD_FP32 - ).all(), f"Outputs({np.array(diffs).tolist()}) has diff bigger than {THRESHOLD_FP32}" - - -@pytest.mark.parametrize( - "name,pt_module,ms_module,init_args,init_kwargs,inputs_args,inputs_kwargs", - ALL_CASES, -) -def test_named_modules_with_pynative_fp16( - name, - pt_module, - ms_module, - init_args, - init_kwargs, - inputs_args, - inputs_kwargs, -): - dtype = "fp16" - ms.set_context(mode=ms.PYNATIVE_MODE, jit_syntax_level=ms.STRICT) + ms.set_context(mode=mode, jit_syntax_level=ms.STRICT) ( pt_model, @@ -176,6 +62,8 @@ def test_named_modules_with_pynative_fp16( pt_dtype, ms_dtype, *inputs_args, **inputs_kwargs ) + # set `hidden_dtype` if requiring, for some modules always compute in float + # precision and require specific `hidden_dtype` to cast before return if "hidden_dtype" in inspect.signature(pt_model.forward).parameters: pt_inputs_kwargs.update({"hidden_dtype": PT_DTYPE_MAPPING[pt_dtype]}) ms_inputs_kwargs.update({"hidden_dtype": MS_DTYPE_MAPPING[ms_dtype]}) @@ -186,6 +74,5 @@ def test_named_modules_with_pynative_fp16( diffs = compute_diffs(pt_outputs, ms_outputs) - assert ( - np.array(diffs) < THRESHOLD_FP16 - ).all(), f"Outputs({np.array(diffs).tolist()}) has diff bigger than {THRESHOLD_FP16}" + THRESHOLD = THRESHOLD_FP32 if dtype == "fp32" else THRESHOLD_FP16 + assert (np.array(diffs) < THRESHOLD).all(), f"Outputs({np.array(diffs).tolist()}) has diff bigger than {THRESHOLD}" diff --git a/tests/diffusers/models/test_transformers.py b/tests/diffusers/models/test_t5_film_transformer.py similarity index 100% rename from tests/diffusers/models/test_transformers.py rename to tests/diffusers/models/test_t5_film_transformer.py