From 43b864cb146f7b4bc21a3a782b8761eb4b6c4be6 Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Mon, 16 Oct 2023 06:16:47 -0500 Subject: [PATCH 01/54] init --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 68ad9ea8..bd67969e 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,7 @@ You might also be interested in another extension I created: [Segment Anything f - `2023/10/11`: [v1.9.1](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.9.1): Use state_dict key to guess mm version, replace match case with if else to support python<3.10, option to save PNG to custom dir (see `Settings/AnimateDiff` for detail), move hints to js, install imageio\[ffmpeg\] automatically when MP4 save fails. - `2023/10/16`: [v1.9.2](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.9.2): Add context generator to completely remove any closed loop, prompt travel support closed loop, infotext fully supported including prompt travel, README refactor +- `2023/10/??`: [v1.10.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.10.0): TODO ## How to Use From 1716e293aeda234c19d28e509b30b58dc16d77a5 Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Fri, 20 Oct 2023 20:34:24 -0500 Subject: [PATCH 02/54] remove timestep hack --- scripts/animatediff_mm.py | 39 ++++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/scripts/animatediff_mm.py b/scripts/animatediff_mm.py index 230d21b1..6f61ea02 100644 --- a/scripts/animatediff_mm.py +++ b/scripts/animatediff_mm.py @@ -4,14 +4,15 @@ import torch from einops import rearrange -from ldm.modules.attention import SpatialTransformer -from ldm.modules.diffusionmodules.openaimodel import (TimestepBlock, - TimestepEmbedSequential) +# from ldm.modules.attention import SpatialTransformer +# from ldm.modules.diffusionmodules.openaimodel import (TimestepBlock, +# TimestepEmbedSequential) from ldm.modules.diffusionmodules.util import GroupNorm32 from modules import hashes, shared, sd_models from modules.devices import cpu, device, torch_gc -from motion_module import MotionWrapper, VanillaTemporalModule +from motion_module import MotionWrapper +# from motion_module import VanillaTemporalModule from scripts.animatediff_logger import logger_animatediff as logger @@ -24,7 +25,7 @@ def __init__(self): self.prev_alpha_cumprod = None self.prev_alpha_cumprod_prev = None self.gn32_original_forward = None - self.tes_original_forward = None + # self.tes_original_forward = None def set_script_dir(self, script_dir): @@ -75,20 +76,20 @@ def inject(self, sd_model, model_name="mm_sd_v15.ckpt"): unet = sd_model.model.diffusion_model self._load(model_name) self.gn32_original_forward = GroupNorm32.forward - self.tes_original_forward = TimestepEmbedSequential.forward gn32_original_forward = self.gn32_original_forward - - def mm_tes_forward(self, x, emb, context=None): - for layer in self: - if isinstance(layer, TimestepBlock): - x = layer(x, emb) - elif isinstance(layer, (SpatialTransformer, VanillaTemporalModule)): - x = layer(x, context) - else: - x = layer(x) - return x - - TimestepEmbedSequential.forward = mm_tes_forward + # self.tes_original_forward = TimestepEmbedSequential.forward + + # def mm_tes_forward(self, x, emb, context=None): + # for layer in self: + # if isinstance(layer, TimestepBlock): + # x = layer(x, emb) + # elif isinstance(layer, (SpatialTransformer, VanillaTemporalModule)): + # x = layer(x, context) + # else: + # x = 
layer(x) + # return x + + # TimestepEmbedSequential.forward = mm_tes_forward if self.mm.using_v2: logger.info(f"Injecting motion module {model_name} into SD1.5 UNet middle block.") unet.middle_block.insert(-1, self.mm.mid_block.motion_modules[0]) @@ -145,7 +146,7 @@ def restore(self, sd_model): else: logger.info(f"Restoring GroupNorm32 forward function.") GroupNorm32.forward = self.gn32_original_forward - TimestepEmbedSequential.forward = self.tes_original_forward + # TimestepEmbedSequential.forward = self.tes_original_forward logger.info(f"Removal finished.") if shared.cmd_opts.lowvram: self.unload() From cef5b99589c92f6abd3f355d657eba9b0552c398 Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Fri, 20 Oct 2023 21:06:03 -0500 Subject: [PATCH 03/54] add_reverse merge to closed_loop --- javascript/hints.js | 1 - scripts/animatediff_output.py | 13 ++++------- scripts/animatediff_ui.py | 43 +++++++++++++++++------------------ 3 files changed, 25 insertions(+), 32 deletions(-) diff --git a/javascript/hints.js b/javascript/hints.js index 9aa1bd4e..ce7d9542 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -8,7 +8,6 @@ if (titles) { // TODO: This is a conflict feature. If some other extension also // titles["Stride"] = "TODO"; // titles["Overlap"] = "TODO"; titles["Save format"] = "In which formats the animation should be saved."; - titles["Reverse"] = "Reverse the resulting animation, remove the first and/or last frame from duplication."; titles["Frame Interpolation"] = "Interpolate between frames with Deforum's FILM implementation. Requires Deforum extension."; titles["Interp X"] = "Replace each input frame with X interpolated output frames."; // titles["Video source"] = "TODO"; diff --git a/scripts/animatediff_output.py b/scripts/animatediff_output.py index 7515b6ab..d560195a 100644 --- a/scripts/animatediff_output.py +++ b/scripts/animatediff_output.py @@ -34,20 +34,15 @@ def output( video_list = self._interp(p, params, video_list, filename) video_paths += self._save(params, video_list, video_path_prefix, res, i) - if len(video_paths) > 0: - if not p.is_api: - res.images = video_paths - else: - # res.images = self._encode_video_to_b64(video_paths) - res.images = video_list + res.images = video_list if p.is_api else video_paths def _add_reverse(self, params: AnimateDiffProcess, video_list: list): - if 0 in params.reverse: + if params.video_length <= params.batch_size and params.closed_loop in ['A']: video_list_reverse = video_list[::-1] - if 1 in params.reverse: + if len(video_list_reverse) > 0: video_list_reverse.pop(0) - if 2 in params.reverse: + if len(video_list_reverse) > 0: video_list_reverse.pop(-1) return video_list + video_list_reverse return video_list diff --git a/scripts/animatediff_ui.py b/scripts/animatediff_ui.py index a292c058..2b7119c0 100644 --- a/scripts/animatediff_ui.py +++ b/scripts/animatediff_ui.py @@ -36,7 +36,6 @@ def __init__( format=["GIF", "PNG"], interp='Off', interp_x=10, - reverse=[], video_source=None, video_path='', latent_power=1, @@ -57,7 +56,6 @@ def __init__( self.format = format self.interp = interp self.interp_x = interp_x - self.reverse = reverse self.video_source = video_source self.video_path = video_path self.latent_power = latent_power @@ -75,19 +73,27 @@ def get_list(self, is_img2img: bool): def get_dict(self, is_img2img: bool): - dict_var = vars(self).copy() - dict_var["mm_hash"] = motion_module.mm.mm_hash[:8] - dict_var.pop("enable") - dict_var.pop("format") - dict_var.pop("video_source") - dict_var.pop("video_path") - 
dict_var.pop("last_frame") - if not is_img2img: - dict_var.pop("latent_power") - dict_var.pop("latent_scale") - dict_var.pop("latent_power_last") - dict_var.pop("latent_scale_last") - return dict_var + infotext = { + "mm_name": self.model, + "mm_hash": motion_module.mm.mm_hash[:8], + "video_length": self.video_length, + "fps": self.fps, + "loop_number": self.loop_number, + "closed_loop": self.closed_loop, + "batch_size": self.batch_size, + "stride": self.stride, + "overlap": self.overlap, + "interp": self.interp, + "interp_x": self.interp_x, + } + if is_img2img: + infotext.update({ + "latent_power": self.latent_power, + "latent_scale": self.latent_scale, + "latent_power_last": self.latent_power_last, + "latent_scale_last": self.latent_scale_last, + }) + return infotext def _check(self): @@ -215,13 +221,6 @@ def refresh_models(*inputs): elem_id=f"{elemid_prefix}save-format", value=self.params.format, ) - self.params.reverse = gr.CheckboxGroup( - choices=["Add Reverse Frame", "Remove head", "Remove tail"], - label="Reverse", - type="index", - elem_id=f"{elemid_prefix}reverse", - value=self.params.reverse - ) with gr.Row(): self.params.interp = gr.Radio( choices=["Off", "FILM"], From 537797f9a832ea5780527cc6ba100006e682fbe7 Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Fri, 20 Oct 2023 21:12:18 -0500 Subject: [PATCH 04/54] ui --- scripts/animatediff_ui.py | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/scripts/animatediff_ui.py b/scripts/animatediff_ui.py index 2b7119c0..f16a4a97 100644 --- a/scripts/animatediff_ui.py +++ b/scripts/animatediff_ui.py @@ -151,15 +151,24 @@ def refresh_models(*inputs): selected = None return gr.Dropdown.update(choices=new_model_list, value=selected) - self.params.model = gr.Dropdown( - choices=model_list, - value=(self.params.model if self.params.model in model_list else None), - label="Motion module", + with gr.Row(): + self.params.model = gr.Dropdown( + choices=model_list, + value=(self.params.model if self.params.model in model_list else None), + label="Motion module", + type="value", + elem_id=f"{elemid_prefix}motion-module", + ) + refresh_model = ToolButton(value="\U0001f504") + refresh_model.click(refresh_models, self.params.model, self.params.model) + + self.params.format = gr.CheckboxGroup( + choices=["GIF", "MP4", "WEBP", "PNG", "TXT"], + label="Save format", type="value", - elem_id=f"{elemid_prefix}motion-module", + elem_id=f"{elemid_prefix}save-format", + value=self.params.format, ) - refresh_model = ToolButton(value="\U0001f504") - refresh_model.click(refresh_models, self.params.model, self.params.model) with gr.Row(): self.params.enable = gr.Checkbox( value=self.params.enable, label="Enable AnimateDiff", @@ -213,14 +222,6 @@ def refresh_models(*inputs): precision=0, elem_id=f"{elemid_prefix}overlap", ) - with gr.Row(): - self.params.format = gr.CheckboxGroup( - choices=["GIF", "MP4", "WEBP", "PNG", "TXT"], - label="Save format", - type="value", - elem_id=f"{elemid_prefix}save-format", - value=self.params.format, - ) with gr.Row(): self.params.interp = gr.Radio( choices=["Off", "FILM"], From 06754129fd7dff4ae5a1a3788b0f2db5a48fc0a1 Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Fri, 20 Oct 2023 22:06:14 -0500 Subject: [PATCH 05/54] actually save prompt travel to infotext in output images --- scripts/animatediff.py | 9 ++++++++- scripts/animatediff_prompt.py | 7 ++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/scripts/animatediff.py b/scripts/animatediff.py index 
c487badd..c82856a6 100644 --- a/scripts/animatediff.py +++ b/scripts/animatediff.py @@ -4,6 +4,7 @@ from modules import script_callbacks, scripts, shared from modules.processing import (Processed, StableDiffusionProcessing, StableDiffusionProcessingImg2Img) +from modules.scripts import PostprocessBatchListArgs from scripts.animatediff_cn import AnimateDiffControl from scripts.animatediff_infv2v import AnimateDiffInfV2V @@ -64,10 +65,16 @@ def before_process_batch(self, p: StableDiffusionProcessing, params: AnimateDiff AnimateDiffI2VLatent().randomize(p, params) + def postprocess_batch_list(self, p: StableDiffusionProcessing, pp: PostprocessBatchListArgs, params: AnimateDiffProcess, **kwargs): + if isinstance(params, dict): params = AnimateDiffProcess(**params) + if params.enable: + self.prompt_scheduler.save_infotext_img(p) + + def postprocess(self, p: StableDiffusionProcessing, res: Processed, params: AnimateDiffProcess): if isinstance(params, dict): params = AnimateDiffProcess(**params) if params.enable: - self.prompt_scheduler.set_infotext(res) + self.prompt_scheduler.save_infotext_txt(res) self.cn_hacker.restore() self.cfg_hacker.restore() self.lora_hacker.restore() diff --git a/scripts/animatediff_prompt.py b/scripts/animatediff_prompt.py index 91588ed4..bfa51960 100644 --- a/scripts/animatediff_prompt.py +++ b/scripts/animatediff_prompt.py @@ -14,7 +14,12 @@ def __init__(self): self.original_prompt = None - def set_infotext(self, res: Processed): + def save_infotext_img(self, p: StableDiffusionProcessing): + if self.prompt_map is not None: + p.prompts = [self.original_prompt for _ in range(p.batch_size)] + + + def save_infotext_txt(self, res: Processed): if self.prompt_map is not None: parts = res.info.split('\nNegative prompt: ', 1) if len(parts) > 1: From 8b28264f2b2170e73308104a0b468c23c7cb6497 Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Fri, 20 Oct 2023 22:50:45 -0500 Subject: [PATCH 06/54] remove hint.js, maintain doc in readme --- javascript/hints.js | 19 ------------------- scripts/animatediff_ui.py | 1 + 2 files changed, 1 insertion(+), 19 deletions(-) delete mode 100644 javascript/hints.js diff --git a/javascript/hints.js b/javascript/hints.js deleted file mode 100644 index ce7d9542..00000000 --- a/javascript/hints.js +++ /dev/null @@ -1,19 +0,0 @@ -if (titles) { // TODO: This is a conflict feature. If some other extension also have the same textContent, they will also show this hint, or substitite this hint. - titles["Motion module"] = "Choose which motion module to be injected into UNet."; - titles["Number of frames"] = "Total length of video in frames."; - titles["FPS"] = "How many frames per second generated GIF will be."; - titles["Display loop number"] = "How many times the animation will loop when displayed, a value of 0 will loop forever."; - // titles["Closed loop"] = "If enabled, will try to make the last frame the same as the first frame."; - // titles["Context batch size"] = "TODO"; - // titles["Stride"] = "TODO"; - // titles["Overlap"] = "TODO"; - titles["Save format"] = "In which formats the animation should be saved."; - titles["Frame Interpolation"] = "Interpolate between frames with Deforum's FILM implementation. 
Requires Deforum extension."; - titles["Interp X"] = "Replace each input frame with X interpolated output frames."; - // titles["Video source"] = "TODO"; - // titles["Video path"] = "TODO"; - // titles["Latent power"] = "TODO"; - // titles["Latent scale"] = "TODO"; - // titles["Optional latent power for last frame"] = "TODO"; - // titles["Optional latent scale for last frame"] = "TODO"; -} \ No newline at end of file diff --git a/scripts/animatediff_ui.py b/scripts/animatediff_ui.py index f16a4a97..d2b89900 100644 --- a/scripts/animatediff_ui.py +++ b/scripts/animatediff_ui.py @@ -136,6 +136,7 @@ def render(self, is_img2img: bool, model_dir: str): elemid_prefix = "img2img-ad-" if is_img2img else "txt2img-ad-" model_list = [f for f in os.listdir(model_dir) if f != ".gitkeep"] with gr.Accordion("AnimateDiff", open=False): + gr.Markdown(value="Please click [this link](https://github.com/continue-revolution/sd-webui-animatediff#webui-parameters) to read the documentation of each parameter.") with gr.Row(): def refresh_models(*inputs): From 7ad5a4ff79d6e654b11d131398543ca24b24506c Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Sat, 21 Oct 2023 00:09:06 -0500 Subject: [PATCH 07/54] infotext copy-paste seems like a non-trivial task. let's leave it to a future update --- scripts/animatediff_infotext.py | 19 +++++++++++++++++++ scripts/animatediff_ui.py | 6 ++++-- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/scripts/animatediff_infotext.py b/scripts/animatediff_infotext.py index 211f0885..1a6c320f 100644 --- a/scripts/animatediff_infotext.py +++ b/scripts/animatediff_infotext.py @@ -4,6 +4,7 @@ from modules.processing import StableDiffusionProcessing, StableDiffusionProcessingImg2Img from scripts.animatediff_ui import AnimateDiffProcess +from scripts.animatediff_logger import logger_animatediff as logger def update_infotext(p: StableDiffusionProcessing, params: AnimateDiffProcess): @@ -14,3 +15,21 @@ def update_infotext(p: StableDiffusionProcessing, params: AnimateDiffProcess): def write_params_txt(info: str): with open(os.path.join(data_path, "params.txt"), "w", encoding="utf8") as file: file.write(info) + + + +def infotext_pasted(infotext, results): + for k, v in results.items(): + if not k.startswith("AnimateDiff"): + continue + + assert isinstance(v, str), f"Expect string but got {v}." 
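+        # v is the flattened AnimateDiff parameter string written by
+        # AnimateDiffProcess.get_dict, e.g.
+        # "enable: True, model: mm_sd_v15_v2.ckpt, video_length: 16, fps: 8, ...";
+        # each "field: value" pair is split back into its own
+        # "AnimateDiff <field>" entry for a future infotext copy-paste mapping.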
+ try: + for items in v.split(', '): + field, value = items.split(': ') + results[f"AnimateDiff {field}"] = value + except Exception: + logger.warn( + f"Failed to parse infotext, legacy format infotext is no longer supported:\n{v}" + ) + break diff --git a/scripts/animatediff_ui.py b/scripts/animatediff_ui.py index d2b89900..6ca2baac 100644 --- a/scripts/animatediff_ui.py +++ b/scripts/animatediff_ui.py @@ -74,7 +74,8 @@ def get_list(self, is_img2img: bool): def get_dict(self, is_img2img: bool): infotext = { - "mm_name": self.model, + "enable": self.enable, + "model": self.model, "mm_hash": motion_module.mm.mm_hash[:8], "video_length": self.video_length, "fps": self.fps, @@ -93,7 +94,8 @@ def get_dict(self, is_img2img: bool): "latent_power_last": self.latent_power_last, "latent_scale_last": self.latent_scale_last, }) - return infotext + infotext_str = ', '.join(f"{k}: {v}" for k, v in infotext.items()) + return infotext_str def _check(self): From ed76e5acc172487414e99af874b6380e9e2eebc6 Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Sat, 21 Oct 2023 01:17:35 -0500 Subject: [PATCH 08/54] readme explain --- README.md | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index b74ad127..a053a198 100644 --- a/README.md +++ b/README.md @@ -91,6 +91,7 @@ Just like how you use ControlNet. Here is a **stale** sample. (TODO: Update API ## WebUI Parameters +1. **Save format** — Format of the output. Choose at least one of "GIF"|"MP4"|"WEBP"|"PNG". Check "TXT" if you want infotext, which will live in the same directory as the output GIF. Infotext is also accessible via `stable-diffusion-webui/params.txt` and outputs in all formats. 1. **Number of frames** — Choose whatever number you like. If you enter 0 (default): @@ -101,17 +102,27 @@ Just like how you use ControlNet. Here is a **stale** sample. (TODO: Update API 1. **FPS** — Frames per second, which is how many frames (images) are shown every second. If 16 frames are generated at 8 frames per second, your GIF’s duration is 2 seconds. If you submit a source video, your FPS will be the same as the source video. 1. **Display loop number** — How many times the GIF is played. A value of `0` means the GIF never stops playing. 1. **Context batch size** — How many frames will be passed into the motion module at once. The model is trained with 16 frames, so it’ll give the best results when the number of frames is set to `16`. Choose [1, 24] for V1 motion modules and [1, 32] for V2 motion modules. -1. **Closed loop** — Closed loop means that this extension will try to make the last frame the same as the first frame. Closed loop can only be made possible when `Number of frames` is greater than `Context batch size`, including when ControlNet is enabled and the source video frame number is greater than `Context batch size` and `Number of frames` is 0. +1. **Closed loop** — Closed loop means that this extension will try to make the last frame the same as the first frame. + 1. When `Number of frames` > `Context batch size`, including when ControlNet is enabled and the source video frame number > `Context batch size` and `Number of frames` is 0, closed loop will be performed by AnimateDiff infinite context generator. + 1. When `Number of frames` <= `Context batch size`, AnimateDiff infinite context generator will not be effective. Only when you choose `A` will AnimateDiff append reversed list of frames to the original list of frames to form closed loop. 
+ + See below for explanation of each choice: + - `N` means absolutely no closed loop - this is the only available option if `Number of frames` is smaller than `Context batch size` other than 0. - `R-P` means that the extension will try to reduce the number of closed loop context. The prompt travel will not be interpolated to be a closed loop. - `R+P` means that the extension will try to reduce the number of closed loop context. The prompt travel will be interpolated to be a closed loop. - `A` means that the extension will aggressively try to make the last frame the same as the first frame. The prompt travel will be interpolated to be a closed loop. -1. **Stride** — Max motion stride as a power of 2 (default: 1). TODO - Need more clear explanation. +1. **Stride** — Max motion stride as a power of 2 (default: 1). + 1. Due to the limitation of the infinite context generator, this parameter is effective only when `Number of frames` > `Context batch size`, including when ControlNet is enabled and the source video frame number > `Context batch size` and `Number of frames` is 0. + 1. "Absolutely no closed loop" is only possible when `Stride` is 1. + 1. For each 1 <= $2^i$ <= `Stride`, the infinite context generator will try to make frames $2^i$ apart temporal consistent. For example, if `Stride` is 4 and `Number of frames` is 8, it will make the following frames temporal consistent: + - `Stride` == 1: [0, 1, 2, 3, 4, 5, 6, 7, 8] + - `Stride` == 2: [0, 2, 4, 6, 8], [1, 3, 5, 7], [0, 4, 8] + - `Stride` == 4: [1, 5], [2, 6], [3, 7] 1. **Overlap** — Number of frames to overlap in context. If overlap is -1 (default): your overlap will be `Context batch size` // 4. -1. **Save format** — Format of the output. Choose at least one of "GIF"|"MP4"|"WEBP"|"PNG". Check "TXT" if you want infotext, which will live in the same directory as the output GIF. Infotext is also accessible via `stable-diffusion-webui/params.txt` and outputs in all formats. + 1. Due to the limitation of the infinite context generator, this parameter is effective only when `Number of frames` > `Context batch size`, including when ControlNet is enabled and the source video frame number > `Context batch size` and `Number of frames` is 0. 1. You can optimize GIF with `gifsicle` (`apt install gifsicle` required, read [#91](https://github.com/continue-revolution/sd-webui-animatediff/pull/91) for more information) and/or `palette` (read [#104](https://github.com/continue-revolution/sd-webui-animatediff/pull/104) for more information). Go to `Settings/AnimateDiff` to enable them. 2. You can set quality and lossless for WEBP via `Settings/AnimateDiff`. Read [#233](https://github.com/continue-revolution/sd-webui-animatediff/pull/233) for more information. -1. **Reverse** — Append reversed frames to your output. See [#112](https://github.com/continue-revolution/sd-webui-animatediff/issues/112) for instruction. 1. **Frame Interpolation** — Interpolate between frames with Deforum's FILM implementation. Requires Deforum extension. [#128](https://github.com/continue-revolution/sd-webui-animatediff/pull/128) 1. **Interp X** — Replace each input frame with X interpolated output frames. [#128](https://github.com/continue-revolution/sd-webui-animatediff/pull/128). 1. **Video source** — [Optional] Video source file for [ControlNet V2V](#controlnet-v2v). You MUST enable ControlNet. It will be the source control for ALL ControlNet units that you enable without submitting a control image or a path to ControlNet panel. 
You can of course submit one control image via `Single Image` tab or an input directory via `Batch` tab, which will override this video source input and work as usual. From 38a07c5b5dbb70503c95c6ef68fbffcd16204828 Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Sat, 21 Oct 2023 01:31:52 -0500 Subject: [PATCH 09/54] readme api --- README.md | 36 +++++++++++++++++++++--------------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index a053a198..c894f85d 100644 --- a/README.md +++ b/README.md @@ -49,7 +49,7 @@ You might also be interested in another extension I created: [Segment Anything f (see `Settings/AnimateDiff` for detail), move hints to js, install imageio\[ffmpeg\] automatically when MP4 save fails. - `2023/10/16`: [v1.9.2](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.9.2): Add context generator to completely remove any closed loop, prompt travel support closed loop, infotext fully supported including prompt travel, README refactor - `2023/10/19`: [v1.9.3](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.9.3): Support webp output format. See [#233](https://github.com/continue-revolution/sd-webui-animatediff/pull/233) for more information. -- `2023/10/??`: [v1.10.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.10.0): TODO +- `2023/10/??`: [v1.10.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.10.0): ? For future update plan, please query [here](https://github.com/continue-revolution/sd-webui-animatediff/pull/224). @@ -65,24 +65,30 @@ For future update plan, please query [here](https://github.com/continue-revoluti 1. You should see the output GIF on the output gallery. You can access GIF output at `stable-diffusion-webui/outputs/{txt2img or img2img}-images/AnimateDiff`. You can also access image frames at `stable-diffusion-webui/outputs/{txt2img or img2img}-images/{date}`. You may choose to save frames for each generation into separate directories in `Settings/AnimateDiff`. ### API -Just like how you use ControlNet. Here is a **stale** sample. (TODO: Update API Sample.) You will get a list of generated frames. You will have to view GIF in your file system, as mentioned at [WebUI](#webui) item 4. For most up-to-date parameters, please read [here](https://github.com/continue-revolution/sd-webui-animatediff/blob/master/scripts/animatediff_ui.py#L26). +Just like how you use ControlNet. Here is a sample. Due to the limitation of WebUI, you will not be able to get a video, but only a list of generated frames. You will have to view GIF in your file system, as mentioned at [WebUI](#webui) item 4. For most up-to-date parameters, please read [here](https://github.com/continue-revolution/sd-webui-animatediff/blob/master/scripts/animatediff_ui.py#L26). ``` 'alwayson_scripts': { 'AnimateDiff': { 'args': [{ - 'enable': True, # enable AnimateDiff - 'video_length': 16, # video frame number, 0-24 for v1 and 0-32 for v2 - 'format': ['GIF', 'PNG'], # 'GIF' | 'MP4' | 'PNG' | 'TXT' - 'loop_number': 0, # 0 = infinite loop - 'fps': 8, # frames per second - 'model': 'mm_sd_v15_v2.ckpt', # motion module name - 'reverse': [], # 0 | 1 | 2 - 0: Add Reverse Frame, 1: Remove head, 2: Remove tail - # parameters below are for img2gif only. 
- 'latent_power': 1, - 'latent_scale': 32, - 'last_frame': None, - 'latent_power_last': 1, - 'latent_scale_last': 32 + 'model': 'mm_sd_v15_v2.ckpt', # Motion module + 'format': ['GIF'], # Save format, 'GIF' | 'MP4' | 'PNG' | 'WEBP' | 'TXT' + 'enable': True, # Enable AnimateDiff + 'video_length': 16, # Number of frames + 'fps': 8, # FPS + 'loop_number': 0, # Display loop number + 'closed_loop': 'R+P', # Closed loop, 'N' | 'R-P' | 'R+P' | 'A' + 'batch_size': 16, # Context batch size + 'stride': 1, # Stride + 'overlap': -1, # Overlap + 'interp': 'Off', # Frame interpolation, 'Off' | 'FILM' + 'interl_x': 10 # Interp X + 'video_source': 'path/to/video.mp4', # Video source + 'video_path': 'path/to/frames', # Video path + 'latent_power': 1, # Latent power + 'latent_scale': 32, # Latent scale + 'last_frame': None, # Optional last frame + 'latent_power_last': 1, # Optional latent power for last frame + 'latent_scale_last': 32 # Optional latent scale for last frame } ] } From e30f2a1a970307c4045a6f99d5af61123f4b1b55 Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Sat, 21 Oct 2023 01:32:53 -0500 Subject: [PATCH 10/54] readme api --- README.md | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index c894f85d..0cc3b5ef 100644 --- a/README.md +++ b/README.md @@ -70,25 +70,25 @@ Just like how you use ControlNet. Here is a sample. Due to the limitation of Web 'alwayson_scripts': { 'AnimateDiff': { 'args': [{ - 'model': 'mm_sd_v15_v2.ckpt', # Motion module - 'format': ['GIF'], # Save format, 'GIF' | 'MP4' | 'PNG' | 'WEBP' | 'TXT' - 'enable': True, # Enable AnimateDiff - 'video_length': 16, # Number of frames - 'fps': 8, # FPS - 'loop_number': 0, # Display loop number - 'closed_loop': 'R+P', # Closed loop, 'N' | 'R-P' | 'R+P' | 'A' - 'batch_size': 16, # Context batch size - 'stride': 1, # Stride - 'overlap': -1, # Overlap - 'interp': 'Off', # Frame interpolation, 'Off' | 'FILM' - 'interl_x': 10 # Interp X - 'video_source': 'path/to/video.mp4', # Video source - 'video_path': 'path/to/frames', # Video path - 'latent_power': 1, # Latent power - 'latent_scale': 32, # Latent scale - 'last_frame': None, # Optional last frame - 'latent_power_last': 1, # Optional latent power for last frame - 'latent_scale_last': 32 # Optional latent scale for last frame + 'model': 'mm_sd_v15_v2.ckpt', # Motion module + 'format': ['GIF'], # Save format, 'GIF' | 'MP4' | 'PNG' | 'WEBP' | 'TXT' + 'enable': True, # Enable AnimateDiff + 'video_length': 16, # Number of frames + 'fps': 8, # FPS + 'loop_number': 0, # Display loop number + 'closed_loop': 'R+P', # Closed loop, 'N' | 'R-P' | 'R+P' | 'A' + 'batch_size': 16, # Context batch size + 'stride': 1, # Stride + 'overlap': -1, # Overlap + 'interp': 'Off', # Frame interpolation, 'Off' | 'FILM' + 'interl_x': 10 # Interp X + 'video_source': 'path/to/video.mp4', # Video source + 'video_path': 'path/to/frames', # Video path + 'latent_power': 1, # Latent power + 'latent_scale': 32, # Latent scale + 'last_frame': None, # Optional last frame + 'latent_power_last': 1, # Optional latent power for last frame + 'latent_scale_last': 32 # Optional latent scale for last frame } ] } From 1474dd51b4ed7404d9ff6c110546e81eb6d7f030 Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Sat, 21 Oct 2023 01:33:16 -0500 Subject: [PATCH 11/54] readme api --- README.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 0cc3b5ef..98b4c2de 100644 --- a/README.md +++ 
b/README.md @@ -76,14 +76,14 @@ Just like how you use ControlNet. Here is a sample. Due to the limitation of Web 'video_length': 16, # Number of frames 'fps': 8, # FPS 'loop_number': 0, # Display loop number - 'closed_loop': 'R+P', # Closed loop, 'N' | 'R-P' | 'R+P' | 'A' - 'batch_size': 16, # Context batch size - 'stride': 1, # Stride - 'overlap': -1, # Overlap - 'interp': 'Off', # Frame interpolation, 'Off' | 'FILM' - 'interl_x': 10 # Interp X - 'video_source': 'path/to/video.mp4', # Video source - 'video_path': 'path/to/frames', # Video path + 'closed_loop': 'R+P', # Closed loop, 'N' | 'R-P' | 'R+P' | 'A' + 'batch_size': 16, # Context batch size + 'stride': 1, # Stride + 'overlap': -1, # Overlap + 'interp': 'Off', # Frame interpolation, 'Off' | 'FILM' + 'interl_x': 10 # Interp X + 'video_source': 'path/to/video.mp4', # Video source + 'video_path': 'path/to/frames', # Video path 'latent_power': 1, # Latent power 'latent_scale': 32, # Latent scale 'last_frame': None, # Optional last frame From c2ad1d22be8d4149005f8fa16616c6def4594d04 Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Sat, 21 Oct 2023 01:33:46 -0500 Subject: [PATCH 12/54] readme api --- README.md | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 98b4c2de..6d631bd6 100644 --- a/README.md +++ b/README.md @@ -69,13 +69,13 @@ Just like how you use ControlNet. Here is a sample. Due to the limitation of Web ``` 'alwayson_scripts': { 'AnimateDiff': { - 'args': [{ - 'model': 'mm_sd_v15_v2.ckpt', # Motion module - 'format': ['GIF'], # Save format, 'GIF' | 'MP4' | 'PNG' | 'WEBP' | 'TXT' - 'enable': True, # Enable AnimateDiff - 'video_length': 16, # Number of frames - 'fps': 8, # FPS - 'loop_number': 0, # Display loop number + 'args': [{ + 'model': 'mm_sd_v15_v2.ckpt', # Motion module + 'format': ['GIF'], # Save format, 'GIF' | 'MP4' | 'PNG' | 'WEBP' | 'TXT' + 'enable': True, # Enable AnimateDiff + 'video_length': 16, # Number of frames + 'fps': 8, # FPS + 'loop_number': 0, # Display loop number 'closed_loop': 'R+P', # Closed loop, 'N' | 'R-P' | 'R+P' | 'A' 'batch_size': 16, # Context batch size 'stride': 1, # Stride @@ -84,13 +84,13 @@ Just like how you use ControlNet. Here is a sample. Due to the limitation of Web 'interl_x': 10 # Interp X 'video_source': 'path/to/video.mp4', # Video source 'video_path': 'path/to/frames', # Video path - 'latent_power': 1, # Latent power - 'latent_scale': 32, # Latent scale - 'last_frame': None, # Optional last frame - 'latent_power_last': 1, # Optional latent power for last frame - 'latent_scale_last': 32 # Optional latent scale for last frame - } - ] + 'latent_power': 1, # Latent power + 'latent_scale': 32, # Latent scale + 'last_frame': None, # Optional last frame + 'latent_power_last': 1, # Optional latent power for last frame + 'latent_scale_last': 32 # Optional latent scale for last frame + } + ] } }, ``` From 3579576a476a8a6e041c44a380c2b6bd8d87c304 Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Sat, 21 Oct 2023 01:34:04 -0500 Subject: [PATCH 13/54] readme api --- README.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 6d631bd6..1b9f62cc 100644 --- a/README.md +++ b/README.md @@ -76,14 +76,14 @@ Just like how you use ControlNet. Here is a sample. 
Due to the limitation of Web 'video_length': 16, # Number of frames 'fps': 8, # FPS 'loop_number': 0, # Display loop number - 'closed_loop': 'R+P', # Closed loop, 'N' | 'R-P' | 'R+P' | 'A' - 'batch_size': 16, # Context batch size - 'stride': 1, # Stride - 'overlap': -1, # Overlap - 'interp': 'Off', # Frame interpolation, 'Off' | 'FILM' - 'interl_x': 10 # Interp X - 'video_source': 'path/to/video.mp4', # Video source - 'video_path': 'path/to/frames', # Video path + 'closed_loop': 'R+P', # Closed loop, 'N' | 'R-P' | 'R+P' | 'A' + 'batch_size': 16, # Context batch size + 'stride': 1, # Stride + 'overlap': -1, # Overlap + 'interp': 'Off', # Frame interpolation, 'Off' | 'FILM' + 'interl_x': 10 # Interp X + 'video_source': 'path/to/video.mp4', # Video source + 'video_path': 'path/to/frames', # Video path 'latent_power': 1, # Latent power 'latent_scale': 32, # Latent scale 'last_frame': None, # Optional last frame From 180f60c2ce54232d8635d81839229110cf32df22 Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Sat, 21 Oct 2023 01:34:29 -0500 Subject: [PATCH 14/54] readme api --- README.md | 44 ++++++++++++++++++++++---------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index 1b9f62cc..ef9f76e9 100644 --- a/README.md +++ b/README.md @@ -69,28 +69,28 @@ Just like how you use ControlNet. Here is a sample. Due to the limitation of Web ``` 'alwayson_scripts': { 'AnimateDiff': { - 'args': [{ - 'model': 'mm_sd_v15_v2.ckpt', # Motion module - 'format': ['GIF'], # Save format, 'GIF' | 'MP4' | 'PNG' | 'WEBP' | 'TXT' - 'enable': True, # Enable AnimateDiff - 'video_length': 16, # Number of frames - 'fps': 8, # FPS - 'loop_number': 0, # Display loop number - 'closed_loop': 'R+P', # Closed loop, 'N' | 'R-P' | 'R+P' | 'A' - 'batch_size': 16, # Context batch size - 'stride': 1, # Stride - 'overlap': -1, # Overlap - 'interp': 'Off', # Frame interpolation, 'Off' | 'FILM' - 'interl_x': 10 # Interp X - 'video_source': 'path/to/video.mp4', # Video source - 'video_path': 'path/to/frames', # Video path - 'latent_power': 1, # Latent power - 'latent_scale': 32, # Latent scale - 'last_frame': None, # Optional last frame - 'latent_power_last': 1, # Optional latent power for last frame - 'latent_scale_last': 32 # Optional latent scale for last frame - } - ] + 'args': [{ + 'model': 'mm_sd_v15_v2.ckpt', # Motion module + 'format': ['GIF'], # Save format, 'GIF' | 'MP4' | 'PNG' | 'WEBP' | 'TXT' + 'enable': True, # Enable AnimateDiff + 'video_length': 16, # Number of frames + 'fps': 8, # FPS + 'loop_number': 0, # Display loop number + 'closed_loop': 'R+P', # Closed loop, 'N' | 'R-P' | 'R+P' | 'A' + 'batch_size': 16, # Context batch size + 'stride': 1, # Stride + 'overlap': -1, # Overlap + 'interp': 'Off', # Frame interpolation, 'Off' | 'FILM' + 'interl_x': 10 # Interp X + 'video_source': 'path/to/video.mp4', # Video source + 'video_path': 'path/to/frames', # Video path + 'latent_power': 1, # Latent power + 'latent_scale': 32, # Latent scale + 'last_frame': None, # Optional last frame + 'latent_power_last': 1, # Optional latent power for last frame + 'latent_scale_last': 32 # Optional latent scale for last frame + } + ] } }, ``` From 2d8d2c7cfe7f3c40f0abe1af280050db82023d46 Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Tue, 24 Oct 2023 04:03:51 -0500 Subject: [PATCH 15/54] first step to add a bunch of code, need to change p.init and cn.main afterwards --- scripts/animatediff_cn.py | 311 ++++++++++++++++++++++++++++++++++---- 1 file changed, 281 insertions(+), 30 
deletions(-) diff --git a/scripts/animatediff_cn.py b/scripts/animatediff_cn.py index ffd66f6e..dc3b5b41 100644 --- a/scripts/animatediff_cn.py +++ b/scripts/animatediff_cn.py @@ -1,15 +1,27 @@ from pathlib import Path +from types import MethodType +from typing import Optional import os import shutil import cv2 import numpy as np import torch -from modules import processing, shared +import hashlib +from PIL import Image, ImageFilter, ImageOps, UnidentifiedImageError +from modules import processing, shared, scripts, img2img, devices, masking, sd_samplers, images from modules.paths import data_path from modules.processing import (StableDiffusionProcessing, StableDiffusionProcessingImg2Img, - StableDiffusionProcessingTxt2Img) + StableDiffusionProcessingTxt2Img, + process_images, + create_binary_mask, + create_random_tensors, + images_tensor_to_samples, + setup_color_correction, + opt_f) +from modules.shared import opts +from modules.sd_samplers_common import images_tensor_to_samples, approximation_indexes from scripts.animatediff_logger import logger_animatediff as logger from scripts.animatediff_ui import AnimateDiffProcess @@ -21,6 +33,7 @@ class AnimateDiffControl: def __init__(self, p: StableDiffusionProcessing, prompt_scheduler: AnimateDiffPromptSchedule): self.original_processing_process_images_hijack = None + self.original_img2img_process_batch_hijack = None self.original_controlnet_main_entry = None self.original_postprocess_batch = None try: @@ -53,12 +66,10 @@ def get_input_frames(): return params.video_path return '' - from scripts import external_code - from scripts.batch_hijack import InputMode, BatchHijack, instance + from scripts.batch_hijack import BatchHijack, instance def hacked_processing_process_images_hijack(self, p: StableDiffusionProcessing, *args, **kwargs): - if self.is_batch: # AnimateDiff does not support this. 
- # we are in img2img batch tab, do a single batch iteration - return self.process_images_cn_batch(p, *args, **kwargs) + from scripts import external_code + from scripts.batch_hijack import InputMode units = external_code.get_all_units_in_processing(p) units = [unit for unit in units if getattr(unit, 'enabled', False)] @@ -66,12 +77,15 @@ def hacked_processing_process_images_hijack(self, p: StableDiffusionProcessing, if len(units) > 0: global_input_frames = get_input_frames() for idx, unit in enumerate(units): + # i2i-batch mode + if getattr(p, '_animatediff_i2i_batch', None) and not unit.image: + unit.input_mode = InputMode.BATCH # if no input given for this unit, use global input if getattr(unit, 'input_mode', InputMode.SIMPLE) == InputMode.BATCH: - if not (isinstance(unit.batch_images, str) and unit.batch_images != ''): - assert global_input_frames != '', 'No input images found for ControlNet module' + if not unit.batch_images: + assert global_input_frames, 'No input images found for ControlNet module' unit.batch_images = global_input_frames - elif unit.image is None: + elif not unit.image: try: cn_script.choose_input_image(p, unit, idx) except: @@ -89,7 +103,9 @@ def hacked_processing_process_images_hijack(self, p: StableDiffusionProcessing, unit.batch_images = shared.listfiles(unit.batch_images) unit_batch_list = [len(unit.batch_images) for unit in units - if getattr(unit, 'input_mode', InputMode.SIMPLE) == InputMode.BATCH] + if getattr(unit, 'input_mode', InputMode.SIMPLE) == InputMode.BATCH] + if getattr(p, '_animatediff_i2i_batch', None): + unit_batch_list.append(len(p.init_images)) if len(unit_batch_list) > 0: video_length = min(unit_batch_list) @@ -112,36 +128,271 @@ def hacked_processing_process_images_hijack(self, p: StableDiffusionProcessing, self.original_processing_process_images_hijack = BatchHijack.processing_process_images_hijack BatchHijack.processing_process_images_hijack = hacked_processing_process_images_hijack processing.process_images_inner = instance.processing_process_images_hijack - + + def hacked_i2i_init(self, all_prompts, all_seeds, all_subseeds): # only hack this when i2i-batch with batch mask + # TODO: hack this! + self.image_cfg_scale: float = self.image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None + + self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) + crop_region = None + + image_mask = self.image_mask + + if image_mask is not None: + # image_mask is passed in as RGBA by Gradio to support alpha masks, + # but we still want to support binary masks. 
+ image_mask = create_binary_mask(image_mask) + + if self.inpainting_mask_invert: + image_mask = ImageOps.invert(image_mask) + + if self.mask_blur_x > 0: + np_mask = np.array(image_mask) + kernel_size = 2 * int(2.5 * self.mask_blur_x + 0.5) + 1 + np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur_x) + image_mask = Image.fromarray(np_mask) + + if self.mask_blur_y > 0: + np_mask = np.array(image_mask) + kernel_size = 2 * int(2.5 * self.mask_blur_y + 0.5) + 1 + np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y) + image_mask = Image.fromarray(np_mask) + + if self.inpaint_full_res: + self.mask_for_overlay = image_mask + mask = image_mask.convert('L') + crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding) + crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height) + x1, y1, x2, y2 = crop_region + + mask = mask.crop(crop_region) + image_mask = images.resize_image(2, mask, self.width, self.height) + self.paste_to = (x1, y1, x2-x1, y2-y1) + else: + image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height) + np_mask = np.array(image_mask) + np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8) + self.mask_for_overlay = Image.fromarray(np_mask) + + self.overlay_images = [] + + latent_mask = self.latent_mask if self.latent_mask is not None else image_mask + + add_color_corrections = opts.img2img_color_correction and self.color_corrections is None + if add_color_corrections: + self.color_corrections = [] + imgs = [] + for img in self.init_images: + + # Save init image + if opts.save_init_img: + self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest() + images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False) + + image = images.flatten(img, opts.img2img_background_color) + + if crop_region is None and self.resize_mode != 3: + image = images.resize_image(self.resize_mode, image, self.width, self.height) + + if image_mask is not None: + image_masked = Image.new('RGBa', (image.width, image.height)) + image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L'))) + + self.overlay_images.append(image_masked.convert('RGBA')) + + # crop_region is not None if we are doing inpaint full res + if crop_region is not None: + image = image.crop(crop_region) + image = images.resize_image(2, image, self.width, self.height) + + if image_mask is not None: + if self.inpainting_fill != 1: + image = masking.fill(image, latent_mask) + + if add_color_corrections: + self.color_corrections.append(setup_color_correction(image)) + + image = np.array(image).astype(np.float32) / 255.0 + image = np.moveaxis(image, 2, 0) + + imgs.append(image) + + if len(imgs) == 1: + batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0) + if self.overlay_images is not None: + self.overlay_images = self.overlay_images * self.batch_size + + if self.color_corrections is not None and len(self.color_corrections) == 1: + self.color_corrections = self.color_corrections * self.batch_size + + elif len(imgs) <= self.batch_size: + self.batch_size = len(imgs) + batch_images = np.array(imgs) + else: + raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less") + + image = torch.from_numpy(batch_images) + image = image.to(shared.device, dtype=devices.dtype_vae) + + if opts.sd_vae_encode_method != 'Full': + 
self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method + + self.init_latent = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model) + devices.torch_gc() + + if self.resize_mode == 3: + self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") + + if image_mask is not None: + init_mask = latent_mask + latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2])) + latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255 + latmask = latmask[0] + latmask = np.around(latmask) + latmask = np.tile(latmask[None], (4, 1, 1)) + + self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype) + self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype) + + # this needs to be fixed to be done in sample() using actual seeds for batches + if self.inpainting_fill == 2: + self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask + elif self.inpainting_fill == 3: + self.init_latent = self.init_latent * self.mask + + self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask) + + def hacked_img2img_process_batch_hijack( + self, p: StableDiffusionProcessingImg2Img, input_dir: str, output_dir: str, inpaint_mask_dir: str, + args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None): + p._animatediff_i2i_batch = 1 # i2i-batch mode, ordinary + output_dir = output_dir.strip() + processing.fix_seed(p) + + images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp", ".tif", ".tiff"))) + + is_inpaint_batch = False + if inpaint_mask_dir: + inpaint_masks = shared.listfiles(inpaint_mask_dir) + is_inpaint_batch = bool(inpaint_masks) + + if is_inpaint_batch: + assert len(inpaint_masks) == 1 or len(inpaint_masks) == len(images), 'The number of masks must be 1 or equal to the number of images.' + logger.info(f"\n[i2i batch] Inpaint batch is enabled. {len(inpaint_masks)} masks found.") + if len(inpaint_masks) > 1: # batch mask + p.init = MethodType(hacked_i2i_init, p) + + logger.info(f"[i2i batch] Will process {len(images)} images, creating {p.n_iter} new videos.") + + # extract "default" params to use in case getting png info fails + prompt = p.prompt + negative_prompt = p.negative_prompt + seed = p.seed + cfg_scale = p.cfg_scale + sampler_name = p.sampler_name + steps = p.steps + frame_images = [] + frame_masks = [] + + for i, image in enumerate(images): + + try: + img = Image.open(image) + except UnidentifiedImageError as e: + print(e) + continue + # Use the EXIF orientation of photos taken by smartphones. 
+ img = ImageOps.exif_transpose(img) + + if to_scale: + p.width = int(img.width * scale_by) + p.height = int(img.height * scale_by) + + frame_images.append(img) + + image_path = Path(image) + if is_inpaint_batch: + if len(inpaint_masks) == 1: + mask_image_path = inpaint_masks[0] + p.image_mask = Image.open(mask_image_path) + else: + # try to find corresponding mask for an image using index matching + mask_image_path = inpaint_masks[i] + frame_masks.append(Image.open(mask_image_path)) + + mask_image = Image.open(mask_image_path) + p.image_mask = mask_image + + if use_png_info: + try: + info_img = frame_images[0] + if png_info_dir: + info_img_path = os.path.join(png_info_dir, os.path.basename(image)) + info_img = Image.open(info_img_path) + from modules import images as imgutil + from modules.generation_parameters_copypaste import parse_generation_parameters + geninfo, _ = imgutil.read_info_from_image(info_img) + parsed_parameters = parse_generation_parameters(geninfo) + parsed_parameters = {k: v for k, v in parsed_parameters.items() if k in (png_info_props or {})} + except Exception: + parsed_parameters = {} + + p.prompt = prompt + (" " + parsed_parameters["Prompt"] if "Prompt" in parsed_parameters else "") + p.negative_prompt = negative_prompt + (" " + parsed_parameters["Negative prompt"] if "Negative prompt" in parsed_parameters else "") + p.seed = int(parsed_parameters.get("Seed", seed)) + p.cfg_scale = float(parsed_parameters.get("CFG scale", cfg_scale)) + p.sampler_name = parsed_parameters.get("Sampler", sampler_name) + p.steps = int(parsed_parameters.get("Steps", steps)) + + p.init_images = frame_images + if len(frame_masks) > 0: + p.image_mask = frame_masks + + proc = scripts.scripts_img2img.run(p, *args) # we should not support this, but just leave it here + if proc is None: + if output_dir: + p.outpath_samples = output_dir + p.override_settings['save_to_dirs'] = False + if p.n_iter > 1 or p.batch_size > 1: + p.override_settings['samples_filename_pattern'] = f'{image_path.stem}-[generation_number]' + else: + p.override_settings['samples_filename_pattern'] = f'{image_path.stem}' + process_images(p) + else: + logger.warn("Warning: you are using an unsupported external script. 
AnimateDiff may not work properly.") + + self.original_img2img_process_batch_hijack = BatchHijack.img2img_process_batch_hijack + BatchHijack.img2img_process_batch_hijack = hacked_img2img_process_batch_hijack + img2img.process_batch = instance.img2img_process_batch_hijack + def restore_batchhijack(self): from scripts.batch_hijack import BatchHijack, instance BatchHijack.processing_process_images_hijack = self.original_processing_process_images_hijack self.original_processing_process_images_hijack = None processing.process_images_inner = instance.processing_process_images_hijack + BatchHijack.img2img_process_batch_hijack = self.original_img2img_process_batch_hijack + self.original_img2img_process_batch_hijack = None + img2img.process_batch = instance.img2img_process_batch_hijack def hack_cn(self): cn_script = self.cn_script - from types import MethodType - from typing import Optional - - from modules import images, masking - from PIL import Image, ImageFilter, ImageOps - - from scripts import external_code, global_state, hook - # from scripts.controlnet_lora import bind_control_lora # do not support control lora for sdxl - from scripts.adapter import Adapter, Adapter_light, StyleAdapter - from scripts.batch_hijack import InputMode - # from scripts.controlnet_lllite import PlugableControlLLLite, clear_all_lllite # do not support controlllite for sdxl - from scripts.controlmodel_ipadapter import (PlugableIPAdapter, - clear_all_ip_adapter) - from scripts.hook import ControlModelType, ControlParams, UnetHook - from scripts.logging import logger - from scripts.processor import model_free_preprocessors def hacked_main_entry(self, p: StableDiffusionProcessing): + from scripts import external_code, global_state, hook + # from scripts.controlnet_lora import bind_control_lora # do not support control lora for sdxl + from scripts.adapter import Adapter, Adapter_light, StyleAdapter + from scripts.batch_hijack import InputMode + # from scripts.controlnet_lllite import PlugableControlLLLite, clear_all_lllite # do not support controlllite for sdxl + from scripts.controlmodel_ipadapter import (PlugableIPAdapter, + clear_all_ip_adapter) + from scripts.hook import ControlModelType, ControlParams, UnetHook + from scripts.logging import logger + from scripts.processor import model_free_preprocessors + + # TODO: i2i-batch mode, what should I change? 
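+            # image_has_mask: an RGBA control input only counts as carrying an
+            # inpaint mask when its alpha channel is actually used (max alpha > 127).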
def image_has_mask(input_image: np.ndarray) -> bool: return ( input_image.ndim == 3 and @@ -585,14 +836,14 @@ def recolor_intensity_post_processing(x, i): if os.path.exists(f'{data_path}/tmp/animatediff-frames/'): shutil.rmtree(f'{data_path}/tmp/animatediff-frames/') - + def hacked_postprocess_batch(self, p, *args, **kwargs): images = kwargs.get('images', []) for post_processor in self.post_processors: for i in range(len(images)): images[i] = post_processor(images[i], i) return - + self.original_controlnet_main_entry = self.cn_script.controlnet_main_entry self.original_postprocess_batch = self.cn_script.postprocess_batch self.cn_script.controlnet_main_entry = MethodType(hacked_main_entry, self.cn_script) From 0947b9b8f94c31043d90575f0e4992244faa5327 Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Wed, 25 Oct 2023 01:45:01 -0500 Subject: [PATCH 16/54] finish i2i-batch --- scripts/animatediff.py | 8 ++++- scripts/animatediff_cn.py | 70 ++++++++++++++++++++++++--------------- 2 files changed, 51 insertions(+), 27 deletions(-) diff --git a/scripts/animatediff.py b/scripts/animatediff.py index c82856a6..a778fed3 100644 --- a/scripts/animatediff.py +++ b/scripts/animatediff.py @@ -4,7 +4,7 @@ from modules import script_callbacks, scripts, shared from modules.processing import (Processed, StableDiffusionProcessing, StableDiffusionProcessingImg2Img) -from modules.scripts import PostprocessBatchListArgs +from modules.scripts import PostprocessBatchListArgs, PostprocessImageArgs from scripts.animatediff_cn import AnimateDiffControl from scripts.animatediff_infv2v import AnimateDiffInfV2V @@ -71,6 +71,12 @@ def postprocess_batch_list(self, p: StableDiffusionProcessing, pp: PostprocessBa self.prompt_scheduler.save_infotext_img(p) + def postprocess_image(self, p: StableDiffusionProcessing, pp: PostprocessImageArgs, params: AnimateDiffProcess, *args): + if isinstance(params, dict): params = AnimateDiffProcess(**params) + if params.enable and isinstance(p, StableDiffusionProcessingImg2Img) and hasattr(p, '_animatediff_paste_to_full'): + p.paste_to = p._animatediff_paste_to_full[p.batch_index] + + def postprocess(self, p: StableDiffusionProcessing, res: Processed, params: AnimateDiffProcess): if isinstance(params, dict): params = AnimateDiffProcess(**params) if params.enable: diff --git a/scripts/animatediff_cn.py b/scripts/animatediff_cn.py index dc3b5b41..90fbe0a8 100644 --- a/scripts/animatediff_cn.py +++ b/scripts/animatediff_cn.py @@ -130,15 +130,16 @@ def hacked_processing_process_images_hijack(self, p: StableDiffusionProcessing, processing.process_images_inner = instance.processing_process_images_hijack def hacked_i2i_init(self, all_prompts, all_seeds, all_subseeds): # only hack this when i2i-batch with batch mask - # TODO: hack this! self.image_cfg_scale: float = self.image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) - crop_region = None + crop_regions = [] + paste_to = [] + masks_for_overlay = [] - image_mask = self.image_mask + image_masks = self.image_mask - if image_mask is not None: + for idx, image_mask in enumerate(image_masks): # image_mask is passed in as RGBA by Gradio to support alpha masks, # but we still want to support binary masks. 
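+            # In i2i batch mode self.image_mask is a list with one mask per frame;
+            # each mask below gets the same binarize / blur / crop-or-resize treatment.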
image_mask = create_binary_mask(image_mask) @@ -159,31 +160,36 @@ def hacked_i2i_init(self, all_prompts, all_seeds, all_subseeds): # only hack thi image_mask = Image.fromarray(np_mask) if self.inpaint_full_res: - self.mask_for_overlay = image_mask + masks_for_overlay.append(image_mask) mask = image_mask.convert('L') crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding) crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height) + crop_regions.append(crop_region) x1, y1, x2, y2 = crop_region mask = mask.crop(crop_region) image_mask = images.resize_image(2, mask, self.width, self.height) - self.paste_to = (x1, y1, x2-x1, y2-y1) + paste_to.append((x1, y1, x2-x1, y2-y1)) else: image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height) np_mask = np.array(image_mask) np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8) - self.mask_for_overlay = Image.fromarray(np_mask) + masks_for_overlay.append(Image.fromarray(np_mask)) - self.overlay_images = [] + image_masks[idx] = image_mask - latent_mask = self.latent_mask if self.latent_mask is not None else image_mask + self.mask_for_overlay = masks_for_overlay[0] # only for saving purpose + if paste_to: + self.paste_to = paste_to[0] + self._animatediff_paste_to_full = paste_to + self.overlay_images = [] add_color_corrections = opts.img2img_color_correction and self.color_corrections is None if add_color_corrections: self.color_corrections = [] imgs = [] - for img in self.init_images: - + for idx, img in enumerate(self.init_images): + latent_mask = (self.latent_mask[idx] if isinstance(self.latent_mask, list) else self.latent_mask) if self.latent_mask is not None else image_masks[idx] # Save init image if opts.save_init_img: self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest() @@ -191,21 +197,21 @@ def hacked_i2i_init(self, all_prompts, all_seeds, all_subseeds): # only hack thi image = images.flatten(img, opts.img2img_background_color) - if crop_region is None and self.resize_mode != 3: + if not crop_regions and self.resize_mode != 3: image = images.resize_image(self.resize_mode, image, self.width, self.height) - if image_mask is not None: + if image_masks: image_masked = Image.new('RGBa', (image.width, image.height)) - image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L'))) + image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(masks_for_overlay[idx].convert('L'))) self.overlay_images.append(image_masked.convert('RGBA')) # crop_region is not None if we are doing inpaint full res - if crop_region is not None: - image = image.crop(crop_region) + if crop_regions: + image = image.crop(crop_regions[idx]) image = images.resize_image(2, image, self.width, self.height) - if image_mask is not None: + if image_masks: if self.inpainting_fill != 1: image = masking.fill(image, latent_mask) @@ -243,13 +249,23 @@ def hacked_i2i_init(self, all_prompts, all_seeds, all_subseeds): # only hack thi if self.resize_mode == 3: self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") - if image_mask is not None: - init_mask = latent_mask - latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2])) - latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255 - latmask = latmask[0] - latmask = np.around(latmask) - latmask = 
np.tile(latmask[None], (4, 1, 1)) + if image_masks is not None: + def process_letmask(init_mask): + # init_mask = latent_mask + latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2])) + latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255 + latmask = latmask[0] + latmask = np.around(latmask) + return np.tile(latmask[None], (4, 1, 1)) + + if self.latent_mask is not None and not isinstance(self.latent_mask, list): + latmask = process_letmask(self.latent_mask) + elif self.latent_mask is not None: + if isinstance(self.latent_mask, list): + latmask = [process_letmask(x) for x in self.latent_mask] + else: + latmask = [process_letmask(x) for x in image_masks] + latmask = np.stack(latmask, axis=0) self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype) self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype) @@ -260,7 +276,7 @@ def hacked_i2i_init(self, all_prompts, all_seeds, all_subseeds): # only hack thi elif self.inpainting_fill == 3: self.init_latent = self.init_latent * self.mask - self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask) + self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_masks) # let's ignore this image_masks which is related to inpaint model with different arch def hacked_img2img_process_batch_hijack( self, p: StableDiffusionProcessingImg2Img, input_dir: str, output_dir: str, inpaint_mask_dir: str, @@ -496,7 +512,7 @@ def set_numpy_seed(p: processing.StableDiffusionProcessing) -> Optional[int]: if getattr(unit, 'input_mode', InputMode.SIMPLE) == InputMode.BATCH: input_images = [] for img in unit.batch_images: - unit.image = img # TODO: SAM extension should use new API + unit.image = img input_image, _ = cn_script.choose_input_image(p, unit, idx) input_images.append(input_image) else: @@ -510,6 +526,8 @@ def set_numpy_seed(p: processing.StableDiffusionProcessing) -> Optional[int]: for idx, input_image in enumerate(input_images): a1111_mask_image : Optional[Image.Image] = getattr(p, "image_mask", None) + if a1111_mask_image and isinstance(a1111_mask_image, list): + a1111_mask_image = a1111_mask_image[idx] if 'inpaint' in unit.module and not image_has_mask(input_image) and a1111_mask_image is not None: a1111_mask = np.array(prepare_mask(a1111_mask_image, p)) if a1111_mask.ndim == 2: From 5039e32c01f5a15d36b9de92356ffa112adc25ae Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Wed, 25 Oct 2023 03:45:19 -0500 Subject: [PATCH 17/54] correct hack --- scripts/animatediff.py | 3 +- scripts/animatediff_cn.py | 272 +--------------------------- scripts/animatediff_i2ibatch.py | 308 ++++++++++++++++++++++++++++++++ scripts/animatediff_ui.py | 11 +- 4 files changed, 322 insertions(+), 272 deletions(-) create mode 100644 scripts/animatediff_i2ibatch.py diff --git a/scripts/animatediff.py b/scripts/animatediff.py index a778fed3..d347e263 100644 --- a/scripts/animatediff.py +++ b/scripts/animatediff.py @@ -15,6 +15,7 @@ from scripts.animatediff_prompt import AnimateDiffPromptSchedule from scripts.animatediff_output import AnimateDiffOutput from scripts.animatediff_ui import AnimateDiffProcess, AnimateDiffUiGroup +from scripts.animatediff_i2ibatch import animatediff_i2ibatch from scripts.animatediff_infotext import update_infotext script_dir = scripts.basedir() @@ -61,7 +62,7 @@ def before_process(self, p: StableDiffusionProcessing, params: AnimateDiffProces def 
before_process_batch(self, p: StableDiffusionProcessing, params: AnimateDiffProcess, **kwargs): if isinstance(params, dict): params = AnimateDiffProcess(**params) - if params.enable and isinstance(p, StableDiffusionProcessingImg2Img): + if params.enable and isinstance(p, StableDiffusionProcessingImg2Img) and not hasattr(p, '_animatediff_i2i_batch'): AnimateDiffI2VLatent().randomize(p, params) diff --git a/scripts/animatediff_cn.py b/scripts/animatediff_cn.py index 90fbe0a8..3b6e3fbd 100644 --- a/scripts/animatediff_cn.py +++ b/scripts/animatediff_cn.py @@ -7,26 +7,18 @@ import cv2 import numpy as np import torch -import hashlib -from PIL import Image, ImageFilter, ImageOps, UnidentifiedImageError -from modules import processing, shared, scripts, img2img, devices, masking, sd_samplers, images +from PIL import Image, ImageFilter, ImageOps +from modules import processing, shared, masking, images from modules.paths import data_path from modules.processing import (StableDiffusionProcessing, StableDiffusionProcessingImg2Img, - StableDiffusionProcessingTxt2Img, - process_images, - create_binary_mask, - create_random_tensors, - images_tensor_to_samples, - setup_color_correction, - opt_f) -from modules.shared import opts -from modules.sd_samplers_common import images_tensor_to_samples, approximation_indexes + StableDiffusionProcessingTxt2Img) from scripts.animatediff_logger import logger_animatediff as logger from scripts.animatediff_ui import AnimateDiffProcess from scripts.animatediff_prompt import AnimateDiffPromptSchedule from scripts.animatediff_infotext import update_infotext +from scripts.animatediff_i2ibatch import animatediff_i2ibatch class AnimateDiffControl: @@ -121,6 +113,7 @@ def hacked_processing_process_images_hijack(self, p: StableDiffusionProcessing, if getattr(unit, 'input_mode', InputMode.SIMPLE) == InputMode.BATCH: unit.batch_images = unit.batch_images[:params.video_length] + animatediff_i2ibatch.cap_init_image(p, params) prompt_scheduler.parse_prompt(p) update_infotext(p, params) return getattr(processing, '__controlnet_original_process_images_inner')(p, *args, **kwargs) @@ -129,267 +122,12 @@ def hacked_processing_process_images_hijack(self, p: StableDiffusionProcessing, BatchHijack.processing_process_images_hijack = hacked_processing_process_images_hijack processing.process_images_inner = instance.processing_process_images_hijack - def hacked_i2i_init(self, all_prompts, all_seeds, all_subseeds): # only hack this when i2i-batch with batch mask - self.image_cfg_scale: float = self.image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None - - self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) - crop_regions = [] - paste_to = [] - masks_for_overlay = [] - - image_masks = self.image_mask - - for idx, image_mask in enumerate(image_masks): - # image_mask is passed in as RGBA by Gradio to support alpha masks, - # but we still want to support binary masks. 
- image_mask = create_binary_mask(image_mask) - - if self.inpainting_mask_invert: - image_mask = ImageOps.invert(image_mask) - - if self.mask_blur_x > 0: - np_mask = np.array(image_mask) - kernel_size = 2 * int(2.5 * self.mask_blur_x + 0.5) + 1 - np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur_x) - image_mask = Image.fromarray(np_mask) - - if self.mask_blur_y > 0: - np_mask = np.array(image_mask) - kernel_size = 2 * int(2.5 * self.mask_blur_y + 0.5) + 1 - np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y) - image_mask = Image.fromarray(np_mask) - - if self.inpaint_full_res: - masks_for_overlay.append(image_mask) - mask = image_mask.convert('L') - crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding) - crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height) - crop_regions.append(crop_region) - x1, y1, x2, y2 = crop_region - - mask = mask.crop(crop_region) - image_mask = images.resize_image(2, mask, self.width, self.height) - paste_to.append((x1, y1, x2-x1, y2-y1)) - else: - image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height) - np_mask = np.array(image_mask) - np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8) - masks_for_overlay.append(Image.fromarray(np_mask)) - - image_masks[idx] = image_mask - - self.mask_for_overlay = masks_for_overlay[0] # only for saving purpose - if paste_to: - self.paste_to = paste_to[0] - self._animatediff_paste_to_full = paste_to - - self.overlay_images = [] - add_color_corrections = opts.img2img_color_correction and self.color_corrections is None - if add_color_corrections: - self.color_corrections = [] - imgs = [] - for idx, img in enumerate(self.init_images): - latent_mask = (self.latent_mask[idx] if isinstance(self.latent_mask, list) else self.latent_mask) if self.latent_mask is not None else image_masks[idx] - # Save init image - if opts.save_init_img: - self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest() - images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False) - - image = images.flatten(img, opts.img2img_background_color) - - if not crop_regions and self.resize_mode != 3: - image = images.resize_image(self.resize_mode, image, self.width, self.height) - - if image_masks: - image_masked = Image.new('RGBa', (image.width, image.height)) - image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(masks_for_overlay[idx].convert('L'))) - - self.overlay_images.append(image_masked.convert('RGBA')) - - # crop_region is not None if we are doing inpaint full res - if crop_regions: - image = image.crop(crop_regions[idx]) - image = images.resize_image(2, image, self.width, self.height) - - if image_masks: - if self.inpainting_fill != 1: - image = masking.fill(image, latent_mask) - - if add_color_corrections: - self.color_corrections.append(setup_color_correction(image)) - - image = np.array(image).astype(np.float32) / 255.0 - image = np.moveaxis(image, 2, 0) - - imgs.append(image) - - if len(imgs) == 1: - batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0) - if self.overlay_images is not None: - self.overlay_images = self.overlay_images * self.batch_size - - if self.color_corrections is not None and len(self.color_corrections) == 1: - self.color_corrections = self.color_corrections * self.batch_size - - elif len(imgs) <= self.batch_size: - self.batch_size = len(imgs) - 
batch_images = np.array(imgs) - else: - raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less") - - image = torch.from_numpy(batch_images) - image = image.to(shared.device, dtype=devices.dtype_vae) - - if opts.sd_vae_encode_method != 'Full': - self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method - - self.init_latent = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model) - devices.torch_gc() - - if self.resize_mode == 3: - self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") - - if image_masks is not None: - def process_letmask(init_mask): - # init_mask = latent_mask - latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2])) - latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255 - latmask = latmask[0] - latmask = np.around(latmask) - return np.tile(latmask[None], (4, 1, 1)) - - if self.latent_mask is not None and not isinstance(self.latent_mask, list): - latmask = process_letmask(self.latent_mask) - elif self.latent_mask is not None: - if isinstance(self.latent_mask, list): - latmask = [process_letmask(x) for x in self.latent_mask] - else: - latmask = [process_letmask(x) for x in image_masks] - latmask = np.stack(latmask, axis=0) - - self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype) - self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype) - - # this needs to be fixed to be done in sample() using actual seeds for batches - if self.inpainting_fill == 2: - self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask - elif self.inpainting_fill == 3: - self.init_latent = self.init_latent * self.mask - - self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_masks) # let's ignore this image_masks which is related to inpaint model with different arch - - def hacked_img2img_process_batch_hijack( - self, p: StableDiffusionProcessingImg2Img, input_dir: str, output_dir: str, inpaint_mask_dir: str, - args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None): - p._animatediff_i2i_batch = 1 # i2i-batch mode, ordinary - output_dir = output_dir.strip() - processing.fix_seed(p) - - images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp", ".tif", ".tiff"))) - - is_inpaint_batch = False - if inpaint_mask_dir: - inpaint_masks = shared.listfiles(inpaint_mask_dir) - is_inpaint_batch = bool(inpaint_masks) - - if is_inpaint_batch: - assert len(inpaint_masks) == 1 or len(inpaint_masks) == len(images), 'The number of masks must be 1 or equal to the number of images.' - logger.info(f"\n[i2i batch] Inpaint batch is enabled. 
{len(inpaint_masks)} masks found.") - if len(inpaint_masks) > 1: # batch mask - p.init = MethodType(hacked_i2i_init, p) - - logger.info(f"[i2i batch] Will process {len(images)} images, creating {p.n_iter} new videos.") - - # extract "default" params to use in case getting png info fails - prompt = p.prompt - negative_prompt = p.negative_prompt - seed = p.seed - cfg_scale = p.cfg_scale - sampler_name = p.sampler_name - steps = p.steps - frame_images = [] - frame_masks = [] - - for i, image in enumerate(images): - - try: - img = Image.open(image) - except UnidentifiedImageError as e: - print(e) - continue - # Use the EXIF orientation of photos taken by smartphones. - img = ImageOps.exif_transpose(img) - - if to_scale: - p.width = int(img.width * scale_by) - p.height = int(img.height * scale_by) - - frame_images.append(img) - - image_path = Path(image) - if is_inpaint_batch: - if len(inpaint_masks) == 1: - mask_image_path = inpaint_masks[0] - p.image_mask = Image.open(mask_image_path) - else: - # try to find corresponding mask for an image using index matching - mask_image_path = inpaint_masks[i] - frame_masks.append(Image.open(mask_image_path)) - - mask_image = Image.open(mask_image_path) - p.image_mask = mask_image - - if use_png_info: - try: - info_img = frame_images[0] - if png_info_dir: - info_img_path = os.path.join(png_info_dir, os.path.basename(image)) - info_img = Image.open(info_img_path) - from modules import images as imgutil - from modules.generation_parameters_copypaste import parse_generation_parameters - geninfo, _ = imgutil.read_info_from_image(info_img) - parsed_parameters = parse_generation_parameters(geninfo) - parsed_parameters = {k: v for k, v in parsed_parameters.items() if k in (png_info_props or {})} - except Exception: - parsed_parameters = {} - - p.prompt = prompt + (" " + parsed_parameters["Prompt"] if "Prompt" in parsed_parameters else "") - p.negative_prompt = negative_prompt + (" " + parsed_parameters["Negative prompt"] if "Negative prompt" in parsed_parameters else "") - p.seed = int(parsed_parameters.get("Seed", seed)) - p.cfg_scale = float(parsed_parameters.get("CFG scale", cfg_scale)) - p.sampler_name = parsed_parameters.get("Sampler", sampler_name) - p.steps = int(parsed_parameters.get("Steps", steps)) - - p.init_images = frame_images - if len(frame_masks) > 0: - p.image_mask = frame_masks - - proc = scripts.scripts_img2img.run(p, *args) # we should not support this, but just leave it here - if proc is None: - if output_dir: - p.outpath_samples = output_dir - p.override_settings['save_to_dirs'] = False - if p.n_iter > 1 or p.batch_size > 1: - p.override_settings['samples_filename_pattern'] = f'{image_path.stem}-[generation_number]' - else: - p.override_settings['samples_filename_pattern'] = f'{image_path.stem}' - process_images(p) - else: - logger.warn("Warning: you are using an unsupported external script. 
AnimateDiff may not work properly.") - - self.original_img2img_process_batch_hijack = BatchHijack.img2img_process_batch_hijack - BatchHijack.img2img_process_batch_hijack = hacked_img2img_process_batch_hijack - img2img.process_batch = instance.img2img_process_batch_hijack - def restore_batchhijack(self): from scripts.batch_hijack import BatchHijack, instance BatchHijack.processing_process_images_hijack = self.original_processing_process_images_hijack self.original_processing_process_images_hijack = None processing.process_images_inner = instance.processing_process_images_hijack - BatchHijack.img2img_process_batch_hijack = self.original_img2img_process_batch_hijack - self.original_img2img_process_batch_hijack = None - img2img.process_batch = instance.img2img_process_batch_hijack def hack_cn(self): diff --git a/scripts/animatediff_i2ibatch.py b/scripts/animatediff_i2ibatch.py new file mode 100644 index 00000000..fa3237e6 --- /dev/null +++ b/scripts/animatediff_i2ibatch.py @@ -0,0 +1,308 @@ +from pathlib import Path +from types import MethodType + +import os +import cv2 +import numpy as np +import torch +import hashlib +from PIL import Image, ImageOps, UnidentifiedImageError +from modules import processing, shared, scripts, img2img, devices, masking, sd_samplers, images +from modules.processing import (StableDiffusionProcessingImg2Img, + process_images, + create_binary_mask, + create_random_tensors, + images_tensor_to_samples, + setup_color_correction, + opt_f) +from modules.shared import opts +from modules.sd_samplers_common import images_tensor_to_samples, approximation_indexes + +from scripts.animatediff_logger import logger_animatediff as logger + + +class AnimateDiffI2IBatch: + + def __init__(self): + self.original_img2img_process_batch_hijack = None + + + def hack(self): + logger.info("Hacking i2i-batch.") + def hacked_i2i_init(self, all_prompts, all_seeds, all_subseeds): # only hack this when i2i-batch with batch mask + self.image_cfg_scale: float = self.image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None + + self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) + crop_regions = [] + paste_to = [] + masks_for_overlay = [] + + image_masks = self.image_mask + + for idx, image_mask in enumerate(image_masks): + # image_mask is passed in as RGBA by Gradio to support alpha masks, + # but we still want to support binary masks. 
+ image_mask = create_binary_mask(image_mask) + + if self.inpainting_mask_invert: + image_mask = ImageOps.invert(image_mask) + + if self.mask_blur_x > 0: + np_mask = np.array(image_mask) + kernel_size = 2 * int(2.5 * self.mask_blur_x + 0.5) + 1 + np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur_x) + image_mask = Image.fromarray(np_mask) + + if self.mask_blur_y > 0: + np_mask = np.array(image_mask) + kernel_size = 2 * int(2.5 * self.mask_blur_y + 0.5) + 1 + np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y) + image_mask = Image.fromarray(np_mask) + + if self.inpaint_full_res: + masks_for_overlay.append(image_mask) + mask = image_mask.convert('L') + crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding) + crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height) + crop_regions.append(crop_region) + x1, y1, x2, y2 = crop_region + + mask = mask.crop(crop_region) + image_mask = images.resize_image(2, mask, self.width, self.height) + paste_to.append((x1, y1, x2-x1, y2-y1)) + else: + image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height) + np_mask = np.array(image_mask) + np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8) + masks_for_overlay.append(Image.fromarray(np_mask)) + + image_masks[idx] = image_mask + + self.mask_for_overlay = masks_for_overlay[0] # only for saving purpose + if paste_to: + self.paste_to = paste_to[0] + self._animatediff_paste_to_full = paste_to + + self.overlay_images = [] + add_color_corrections = opts.img2img_color_correction and self.color_corrections is None + if add_color_corrections: + self.color_corrections = [] + imgs = [] + for idx, img in enumerate(self.init_images): + latent_mask = (self.latent_mask[idx] if isinstance(self.latent_mask, list) else self.latent_mask) if self.latent_mask is not None else image_masks[idx] + # Save init image + if opts.save_init_img: + self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest() + images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False) + + image = images.flatten(img, opts.img2img_background_color) + + if not crop_regions and self.resize_mode != 3: + image = images.resize_image(self.resize_mode, image, self.width, self.height) + + if image_masks: + image_masked = Image.new('RGBa', (image.width, image.height)) + image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(masks_for_overlay[idx].convert('L'))) + + self.overlay_images.append(image_masked.convert('RGBA')) + + # crop_region is not None if we are doing inpaint full res + if crop_regions: + image = image.crop(crop_regions[idx]) + image = images.resize_image(2, image, self.width, self.height) + + if image_masks: + if self.inpainting_fill != 1: + image = masking.fill(image, latent_mask) + + if add_color_corrections: + self.color_corrections.append(setup_color_correction(image)) + + image = np.array(image).astype(np.float32) / 255.0 + image = np.moveaxis(image, 2, 0) + + imgs.append(image) + + if len(imgs) == 1: + batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0) + if self.overlay_images is not None: + self.overlay_images = self.overlay_images * self.batch_size + + if self.color_corrections is not None and len(self.color_corrections) == 1: + self.color_corrections = self.color_corrections * self.batch_size + + elif len(imgs) <= self.batch_size: + self.batch_size = len(imgs) + 
batch_images = np.array(imgs) + else: + raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less") + + image = torch.from_numpy(batch_images) + image = image.to(shared.device, dtype=devices.dtype_vae) + + if opts.sd_vae_encode_method != 'Full': + self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method + + self.init_latent = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model) + devices.torch_gc() + + if self.resize_mode == 3: + self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") + + if image_masks is not None: + def process_letmask(init_mask): + # init_mask = latent_mask + latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2])) + latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255 + latmask = latmask[0] + latmask = np.around(latmask) + return np.tile(latmask[None], (4, 1, 1)) + + if self.latent_mask is not None and not isinstance(self.latent_mask, list): + latmask = process_letmask(self.latent_mask) + elif self.latent_mask is not None: + if isinstance(self.latent_mask, list): + latmask = [process_letmask(x) for x in self.latent_mask] + else: + latmask = [process_letmask(x) for x in image_masks] + latmask = np.stack(latmask, axis=0) + + self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype) + self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype) + + # this needs to be fixed to be done in sample() using actual seeds for batches + if self.inpainting_fill == 2: + self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask + elif self.inpainting_fill == 3: + self.init_latent = self.init_latent * self.mask + + self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_masks) # let's ignore this image_masks which is related to inpaint model with different arch + + def hacked_img2img_process_batch_hijack( + self, p: StableDiffusionProcessingImg2Img, input_dir: str, output_dir: str, inpaint_mask_dir: str, + args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None): + p._animatediff_i2i_batch = 1 # i2i-batch mode, ordinary + output_dir = output_dir.strip() + processing.fix_seed(p) + + images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp", ".tif", ".tiff"))) + + is_inpaint_batch = False + if inpaint_mask_dir: + inpaint_masks = shared.listfiles(inpaint_mask_dir) + is_inpaint_batch = bool(inpaint_masks) + + if is_inpaint_batch: + assert len(inpaint_masks) == 1 or len(inpaint_masks) == len(images), 'The number of masks must be 1 or equal to the number of images.' + logger.info(f"\n[i2i batch] Inpaint batch is enabled. 
{len(inpaint_masks)} masks found.") + if len(inpaint_masks) > 1: # batch mask + p.init = MethodType(hacked_i2i_init, p) + + logger.info(f"[i2i batch] Will process {len(images)} images, creating {p.n_iter} new videos.") + + # extract "default" params to use in case getting png info fails + prompt = p.prompt + negative_prompt = p.negative_prompt + seed = p.seed + cfg_scale = p.cfg_scale + sampler_name = p.sampler_name + steps = p.steps + frame_images = [] + frame_masks = [] + + for i, image in enumerate(images): + + try: + img = Image.open(image) + except UnidentifiedImageError as e: + print(e) + continue + # Use the EXIF orientation of photos taken by smartphones. + img = ImageOps.exif_transpose(img) + + if to_scale: + p.width = int(img.width * scale_by) + p.height = int(img.height * scale_by) + + frame_images.append(img) + + image_path = Path(image) + if is_inpaint_batch: + if len(inpaint_masks) == 1: + mask_image_path = inpaint_masks[0] + p.image_mask = Image.open(mask_image_path) + else: + # try to find corresponding mask for an image using index matching + mask_image_path = inpaint_masks[i] + frame_masks.append(Image.open(mask_image_path)) + + mask_image = Image.open(mask_image_path) + p.image_mask = mask_image + + if use_png_info: + try: + info_img = frame_images[0] + if png_info_dir: + info_img_path = os.path.join(png_info_dir, os.path.basename(image)) + info_img = Image.open(info_img_path) + from modules import images as imgutil + from modules.generation_parameters_copypaste import parse_generation_parameters + geninfo, _ = imgutil.read_info_from_image(info_img) + parsed_parameters = parse_generation_parameters(geninfo) + parsed_parameters = {k: v for k, v in parsed_parameters.items() if k in (png_info_props or {})} + except Exception: + parsed_parameters = {} + + p.prompt = prompt + (" " + parsed_parameters["Prompt"] if "Prompt" in parsed_parameters else "") + p.negative_prompt = negative_prompt + (" " + parsed_parameters["Negative prompt"] if "Negative prompt" in parsed_parameters else "") + p.seed = int(parsed_parameters.get("Seed", seed)) + p.cfg_scale = float(parsed_parameters.get("CFG scale", cfg_scale)) + p.sampler_name = parsed_parameters.get("Sampler", sampler_name) + p.steps = int(parsed_parameters.get("Steps", steps)) + + p.init_images = frame_images + if len(frame_masks) > 0: + p.image_mask = frame_masks + + proc = scripts.scripts_img2img.run(p, *args) # we should not support this, but just leave it here + if proc is None: + if output_dir: + p.outpath_samples = output_dir + p.override_settings['save_to_dirs'] = False + if p.n_iter > 1 or p.batch_size > 1: + p.override_settings['samples_filename_pattern'] = f'{image_path.stem}-[generation_number]' + else: + p.override_settings['samples_filename_pattern'] = f'{image_path.stem}' + process_images(p) + else: + logger.warn("Warning: you are using an unsupported external script. 
AnimateDiff may not work properly.") + + from scripts.batch_hijack import BatchHijack, instance + self.original_img2img_process_batch_hijack = BatchHijack.img2img_process_batch_hijack + BatchHijack.img2img_process_batch_hijack = hacked_img2img_process_batch_hijack + img2img.process_batch = instance.img2img_process_batch_hijack + + + def restore(self): + logger.info("Restoring i2i-batch.") + from scripts.batch_hijack import BatchHijack, instance + if self.original_img2img_process_batch_hijack is not None: + BatchHijack.img2img_process_batch_hijack = self.original_img2img_process_batch_hijack + self.original_img2img_process_batch_hijack = None + img2img.process_batch = instance.img2img_process_batch_hijack + + + def cap_init_image(self, p: StableDiffusionProcessingImg2Img, params): + if params.enable and isinstance(p, StableDiffusionProcessingImg2Img) and hasattr(p, '_animatediff_i2i_batch'): + if len(p.init_images) > params.video_length: + p.init_images = p.init_images[:params.video_length] + if p.image_mask and isinstance(p.image_mask, list) and len(p.image_mask) > params.video_length: + p.image_mask = p.image_mask[:params.video_length] + if len(p.init_images) < params.video_length: + params.video_length = len(p.init_images) + if len(p.init_images) < params.batch_size: + params.batch_size = len(p.init_images) + + + +animatediff_i2ibatch = AnimateDiffI2IBatch() diff --git a/scripts/animatediff_ui.py b/scripts/animatediff_ui.py index f6b47899..7bedccb2 100644 --- a/scripts/animatediff_ui.py +++ b/scripts/animatediff_ui.py @@ -7,6 +7,7 @@ from modules.processing import StableDiffusionProcessing from scripts.animatediff_mm import mm_animatediff as motion_module +from scripts.animatediff_i2ibatch import animatediff_i2ibatch class ToolButton(gr.Button, gr.components.FormComponent): @@ -67,7 +68,9 @@ def __init__( def get_list(self, is_img2img: bool): list_var = list(vars(self).values()) - if not is_img2img: + if is_img2img: + animatediff_i2ibatch.hack() + else: list_var = list_var[:-5] return list_var @@ -301,12 +304,12 @@ def update_frames(video_source): type="pil", ) with gr.Row(): - unload = gr.Button( - value="Move motion module to CPU (default if lowvram)" - ) + unload = gr.Button(value="Move motion module to CPU (default if lowvram)") remove = gr.Button(value="Remove motion module from any memory") + restore_i2ibatch = gr.Button(value="Restore img2img batch") unload.click(fn=motion_module.unload) remove.click(fn=motion_module.remove) + restore_i2ibatch.click(fn=animatediff_i2ibatch.restore) return self.register_unit(is_img2img) From d809875cf4208021410ce7b39125401c304474db Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Wed, 25 Oct 2023 04:09:31 -0500 Subject: [PATCH 18/54] works --- scripts/animatediff.py | 1 - scripts/animatediff_i2ibatch.py | 30 +++++++++++------------------- scripts/animatediff_ui.py | 2 -- 3 files changed, 11 insertions(+), 22 deletions(-) diff --git a/scripts/animatediff.py b/scripts/animatediff.py index d347e263..a5593c02 100644 --- a/scripts/animatediff.py +++ b/scripts/animatediff.py @@ -15,7 +15,6 @@ from scripts.animatediff_prompt import AnimateDiffPromptSchedule from scripts.animatediff_output import AnimateDiffOutput from scripts.animatediff_ui import AnimateDiffProcess, AnimateDiffUiGroup -from scripts.animatediff_i2ibatch import animatediff_i2ibatch from scripts.animatediff_infotext import update_infotext script_dir = scripts.basedir() diff --git a/scripts/animatediff_i2ibatch.py b/scripts/animatediff_i2ibatch.py index fa3237e6..96c28ccf 100644 --- 
a/scripts/animatediff_i2ibatch.py +++ b/scripts/animatediff_i2ibatch.py @@ -23,12 +23,10 @@ class AnimateDiffI2IBatch: - def __init__(self): - self.original_img2img_process_batch_hijack = None - - def hack(self): logger.info("Hacking i2i-batch.") + original_img2img_process_batch = img2img.process_batch + def hacked_i2i_init(self, all_prompts, all_seeds, all_subseeds): # only hack this when i2i-batch with batch mask self.image_cfg_scale: float = self.image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None @@ -179,9 +177,15 @@ def process_letmask(init_mask): self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_masks) # let's ignore this image_masks which is related to inpaint model with different arch def hacked_img2img_process_batch_hijack( - self, p: StableDiffusionProcessingImg2Img, input_dir: str, output_dir: str, inpaint_mask_dir: str, + p: StableDiffusionProcessingImg2Img, input_dir: str, output_dir: str, inpaint_mask_dir: str, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None): - p._animatediff_i2i_batch = 1 # i2i-batch mode, ordinary + if p.scripts: + for script in p.scripts.alwayson_scripts: + if script.title().lower() == "animatediff": + p._animatediff_i2i_batch = 1 # i2i-batch mode, ordinary + + if not hasattr(p, '_animatediff_i2i_batch'): + return original_img2img_process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale, scale_by, use_png_info, png_info_props, png_info_dir) output_dir = output_dir.strip() processing.fix_seed(p) @@ -277,19 +281,7 @@ def hacked_img2img_process_batch_hijack( else: logger.warn("Warning: you are using an unsupported external script. AnimateDiff may not work properly.") - from scripts.batch_hijack import BatchHijack, instance - self.original_img2img_process_batch_hijack = BatchHijack.img2img_process_batch_hijack - BatchHijack.img2img_process_batch_hijack = hacked_img2img_process_batch_hijack - img2img.process_batch = instance.img2img_process_batch_hijack - - - def restore(self): - logger.info("Restoring i2i-batch.") - from scripts.batch_hijack import BatchHijack, instance - if self.original_img2img_process_batch_hijack is not None: - BatchHijack.img2img_process_batch_hijack = self.original_img2img_process_batch_hijack - self.original_img2img_process_batch_hijack = None - img2img.process_batch = instance.img2img_process_batch_hijack + img2img.process_batch = hacked_img2img_process_batch_hijack def cap_init_image(self, p: StableDiffusionProcessingImg2Img, params): diff --git a/scripts/animatediff_ui.py b/scripts/animatediff_ui.py index 7bedccb2..054b2ed4 100644 --- a/scripts/animatediff_ui.py +++ b/scripts/animatediff_ui.py @@ -306,10 +306,8 @@ def update_frames(video_source): with gr.Row(): unload = gr.Button(value="Move motion module to CPU (default if lowvram)") remove = gr.Button(value="Remove motion module from any memory") - restore_i2ibatch = gr.Button(value="Restore img2img batch") unload.click(fn=motion_module.unload) remove.click(fn=motion_module.remove) - restore_i2ibatch.click(fn=animatediff_i2ibatch.restore) return self.register_unit(is_img2img) From 2f30b6c72d9b40c3444c90c5686015dae5e10a82 Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Wed, 25 Oct 2023 04:24:32 -0500 Subject: [PATCH 19/54] readme --- README.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index ecee8cfe..9e8bdf91 100644 --- a/README.md +++ b/README.md @@ -50,7 +50,7 @@ You might also be interested in 
another extension I created: [Segment Anything f - `2023/10/16`: [v1.9.2](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.9.2): Add context generator to completely remove any closed loop, prompt travel support closed loop, infotext fully supported including prompt travel, README refactor - `2023/10/19`: [v1.9.3](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.9.3): Support webp output format. See [#233](https://github.com/continue-revolution/sd-webui-animatediff/pull/233) for more information. - `2023/10/21`: [v1.9.4](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.9.4): Save prompt travel to output images, `Reverse` merged to `Closed loop` (See [WebUI Parameters](#webui-parameters)), remove `TimestepEmbedSequential` hijack, remove `hints.js`, better explanation of several context-related parameters. -- `2023/10/??`: [v1.10.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.10.0): ? +- `2023/10/25`: [v1.10.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.10.0): Support img2img batch. You need ControlNet installed to make it work properly (you do not need to enable ControlNet). See [ControlNet V2V](#controlnet-v2v) for more information. For future update plan, please query [here](https://github.com/continue-revolution/sd-webui-animatediff/pull/224). @@ -177,10 +177,11 @@ smile ## ControlNet V2V -You need to go to txt2img and submit source video or path to frames. Each ControlNet will find control images according to this priority: +You need to go to txt2img / img2img-batch and submit source video or path to frames. Each ControlNet will find control images according to this priority: 1. ControlNet `Single Image` tab or `Batch` tab. Simply upload a control image or a directory of control frames is enough. +1. Img2img Batch tab `Input directory` if you are using img2img batch. If you upload a directory of control frames, it will be the source control for ALL ControlNet units that you enable without submitting a control image or a path to ControlNet panel. 1. AnimateDiff `Video Source`. If you upload a video through `Video Source`, it will be the source control for ALL ControlNet units that you enable without submitting a control image or a path to ControlNet panel. -1. AnimateDiff `Video Path`. If you upload a video through `Video Path`, it will be the source control for ALL ControlNet units that you enable without submitting a control image or a path to ControlNet panel. +1. AnimateDiff `Video Path`. If you upload a path to frames through `Video Path`, it will be the source control for ALL ControlNet units that you enable without submitting a control image or a path to ControlNet panel. `Number of frames` will be capped to the minimum number of images among all **folders** you provide. Each control image in each folder will be applied to one single frame. If you upload one single image for a ControlNet unit, that image will control **ALL** frames. 
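To make the lookup order concrete, here is a minimal sketch of the priority logic described above. It is illustrative only: the function name, its parameters, and the `*.png` glob are invented for the example and do not correspond to the extension's internal API.

```python
from pathlib import Path

def pick_control_frames(unit_image, unit_batch_dir, i2i_input_dir,
                        video_source_frames, video_path_dir):
    """Return the control frames one ControlNet unit will use, mirroring the priority above."""
    def frames_in(folder):
        # treat a missing / empty directory as "no frames provided"
        return sorted(Path(folder).glob("*.png")) if folder else []

    if unit_image is not None:              # 1. ControlNet Single Image tab
        return [unit_image]                 #    (a single image controls ALL frames)
    if frames_in(unit_batch_dir):           # 1. ControlNet Batch tab
        return frames_in(unit_batch_dir)
    if frames_in(i2i_input_dir):            # 2. img2img batch "Input directory"
        return frames_in(i2i_input_dir)
    if video_source_frames:                 # 3. AnimateDiff "Video Source" (frames already extracted)
        return list(video_source_frames)
    return frames_in(video_path_dir)        # 4. AnimateDiff "Video Path"

# `Number of frames` is then capped by the smallest control folder, e.g.
# video_length = min(video_length, min(len(f) for f in per_unit_frames if len(f) > 1))
```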
From 55f525ef3f7bb133c9c0703e7394002cf0028b67 Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Sat, 28 Oct 2023 19:53:37 -0500 Subject: [PATCH 20/54] partially finish hotshotxl, have to deal with another bug --- .gitignore | 4 ++-- mm_zoo.json | 38 ------------------------------- motion_module.py | 48 +++++++++++++++++++++++---------------- scripts/animatediff_mm.py | 27 ++++------------------ 4 files changed, 34 insertions(+), 83 deletions(-) delete mode 100644 mm_zoo.json diff --git a/.gitignore b/.gitignore index 096491c2..98fba9a2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ __pycache__ -model/*.ckpt -model/*.pth +model/*.* +model/*.* TODO.md \ No newline at end of file diff --git a/mm_zoo.json b/mm_zoo.json deleted file mode 100644 index 1091d54d..00000000 --- a/mm_zoo.json +++ /dev/null @@ -1,38 +0,0 @@ -{ - "aa7fd8a200a89031edd84487e2a757c5315460eca528fa70d4b3885c399bffd5": { - "name": "mm_sd_14.ckpt", - "arch": 1 - }, - "cf16ea656cb16124990c8e2c70a29c793f9841f3a2223073fac8bd89ebd9b69a": { - "name": "mm_sd_15.ckpt", - "arch": 1 - }, - "0aaf157b9c51a0ae07cb5d9ea7c51299f07bddc6f52025e1f9bb81cd763631df": { - "name": "mm-Stabilized_high.pth", - "arch": 1 - }, - "39de8b71b1c09f10f4602f5d585d82771a60d3cf282ba90215993e06afdfe875": { - "name": "mm-Stabilized_mid.pth", - "arch": 1 - }, - "3cb569f7ce3dc6a10aa8438e666265cb9be3120d8f205de6a456acf46b6c99f4": { - "name": "temporaldiff-v1-animatediff.ckpt", - "arch": 1 - }, - "0ba406706cd9ba7e272e96ba01107cb44f63b65ab052808cb8bdb9ed404ca68a": { - "name": "mm_sd_v14.safetensors", - "arch": 1 - }, - "69ed0f5fef82b110aca51bcab73b21104242bc65d6ab4b8b2a2a94d31cad1bf0": { - "name": "mm_sd_v15_v2.ckpt", - "arch": 2 - }, - "096f2f9ce84dabb79f25961a1266c1e6cda8c3bce0a4d2a60cf656ce5a747642": { - "name": "mm_sd_v15.safetensors", - "arch": 2 - }, - "1cf5d6d6062b8d3d72cd5cdcabb578befd79b7a3e1ea41bf15d3ca594c0998da": { - "name": "mm_sd_v15_v2.safetensors", - "arch": 2 - } -} diff --git a/motion_module.py b/motion_module.py index b61fc041..538cb913 100644 --- a/motion_module.py +++ b/motion_module.py @@ -1,3 +1,4 @@ +from enum import Enum from typing import Optional import torch @@ -11,6 +12,12 @@ import math +class MotionModuleType(Enum): + AnimateDiffV1 = "AnimateDiff V1, Yuwei GUo, Shanghai AI Lab" + AnimateDiffV2 = "AnimateDiff V2, Yuwei Guo, Shanghai AI Lab" + HotShotXL = "HotShot-XL, John Mullan, Natural Synthetics Inc" + + def zero_module(module): # Zero out the parameters of a module and return it. 
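# (Zeroing the output projection makes each temporal block a no-op right after
# construction, so an injected motion module leaves the UNet unchanged until the
# checkpoint weights are loaded over these zeros.)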
for p in module.parameters(): @@ -19,37 +26,38 @@ def zero_module(module): class MotionWrapper(nn.Module): - def __init__(self, mm_hash: str, using_v2: bool): + def __init__(self, mm_name: str, mm_hash: str, mm_type: MotionModuleType): super().__init__() - if using_v2: - max_len = 32 - else: - max_len = 24 + is_v2 = mm_type == MotionModuleType.AnimateDiffV2 + is_sdxl = mm_type == MotionModuleType.HotShotXL + max_len = 32 if is_v2 else 24 + in_channels = (320, 640, 1280) if is_sdxl else (320, 640, 1280, 1280) self.down_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) - for c in (320, 640, 1280, 1280): - self.down_blocks.append(MotionModule(c, max_len=max_len)) - for c in (1280, 1280, 640, 320): - self.up_blocks.append(MotionModule(c, is_up=True, max_len=max_len)) - if using_v2: - self.mid_block = MotionModule(1280, max_len=max_len, is_mid=using_v2) + for c in in_channels: + self.down_blocks.append(MotionModule(c, num_mm=2, max_len=max_len, is_sdxl=is_sdxl)) + self.up_blocks.insert(0,MotionModule(c, num_mm=3, max_len=max_len, is_sdxl=is_sdxl)) + if is_v2: + self.mid_block = MotionModule(1280, num_mm=1, max_len=max_len) + self.mm_name = mm_name + self.mm_type = mm_type self.mm_hash = mm_hash - self.using_v2 = using_v2 class MotionModule(nn.Module): - def __init__(self, in_channels, is_up=False, is_mid=False, max_len=24): + def __init__(self, in_channels, num_mm, max_len, is_sdxl=False): super().__init__() - if is_mid: - self.motion_modules = nn.ModuleList([get_motion_module(in_channels, max_len)]) + motion_modules = nn.ModuleList([get_motion_module(in_channels, max_len) for _ in range(num_mm)]) + if is_sdxl: + self.temporal_attentions = motion_modules else: - self.motion_modules = nn.ModuleList([get_motion_module(in_channels, max_len), get_motion_module(in_channels, max_len)]) - if is_up: - self.motion_modules.append(get_motion_module(in_channels, max_len)) + self.motion_modules = motion_modules + -def get_motion_module(in_channels, max_len): - return VanillaTemporalModule(in_channels=in_channels, temporal_position_encoding_max_len=max_len) +def get_motion_module(in_channels, max_len, is_sdxl=False): + vtm = VanillaTemporalModule(in_channels=in_channels, temporal_position_encoding_max_len=max_len) + return vtm.temporal_transformer if is_sdxl else vtm class VanillaTemporalModule(nn.Module): diff --git a/scripts/animatediff_mm.py b/scripts/animatediff_mm.py index 6f61ea02..2a9f857b 100644 --- a/scripts/animatediff_mm.py +++ b/scripts/animatediff_mm.py @@ -37,15 +37,14 @@ def _load(self, model_name): shared.opts.data.get("animatediff_model_path", os.path.join(self.script_dir, "model")), model_name, ) - model_hash, using_v2, guess = self._hash(model_path, model_name) if not os.path.isfile(model_path): raise RuntimeError("Please download models manually.") - if self.mm is None or self.mm.mm_hash != model_hash: + if self.mm is None or self.mm.mm_name != model_name: logger.info(f"Loading motion module {model_name} from {model_path}") + model_hash = hashes.sha256(model_path, f"AnimateDiff/{model_name}") mm_state_dict = sd_models.read_state_dict(model_path) - if guess: - using_v2 = "mid_block.motion_modules.0.temporal_transformer.proj_out.bias" in mm_state_dict.keys() - logger.warn(f"Guessed mm architecture : {'v2' if using_v2 else 'v1'}") + using_v2 = "mid_block.motion_modules.0.temporal_transformer.proj_out.bias" in mm_state_dict.keys() + logger.warn(f"Guessed mm architecture : {'v2' if using_v2 else 'v1'}") self.mm = MotionWrapper(model_hash, using_v2) missed_keys = 
self.mm.load_state_dict(mm_state_dict) logger.warn(f"Missing keys {missed_keys}") @@ -54,24 +53,6 @@ def _load(self, model_name): self.mm.half() - def _hash(self, model_path: str, model_name="mm_sd_v15.ckpt"): - model_hash = hashes.sha256(model_path, f"AnimateDiff/{model_name}") - with open(os.path.join(self.script_dir, "mm_zoo.json"), "r") as f: - model_zoo = json.load(f) - if model_hash in model_zoo: - model_official_name = model_zoo[model_hash]["name"] - logger.info(f"You are using tested mm {model_official_name}.") - return model_hash, model_zoo[model_hash]["arch"] == 2, False - else: - logger.warn( - f"You are using unknown mm {model_name}. " - "Either your download is incomplete or your model has not been tested. " - "Please use at your own risk. " - "AnimateDiff will guess mm architecture via state_dict." - ) - return model_hash, False, True - - def inject(self, sd_model, model_name="mm_sd_v15.ckpt"): unet = sd_model.model.diffusion_model self._load(model_name) From 601793844f0f3f5ba8784433a5a3fd2b7466f8fa Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Sat, 28 Oct 2023 22:01:44 -0500 Subject: [PATCH 21/54] unknown bug --- motion_module.py | 51 +++++++++++++++------- scripts/animatediff.py | 2 +- scripts/animatediff_mm.py | 89 +++++++++++++++++++-------------------- 3 files changed, 81 insertions(+), 61 deletions(-) diff --git a/motion_module.py b/motion_module.py index 538cb913..442f5da0 100644 --- a/motion_module.py +++ b/motion_module.py @@ -18,6 +18,17 @@ class MotionModuleType(Enum): HotShotXL = "HotShot-XL, John Mullan, Natural Synthetics Inc" + @staticmethod + def get_mm_type(state_dict: dict): + keys = list(state_dict.keys()) + if any(["mid_block" in k for k in keys]): + return MotionModuleType.AnimateDiffV2 + elif any(["temporal_attentions" in k for k in keys]): + return MotionModuleType.HotShotXL + else: + return MotionModuleType.AnimateDiffV1 + + def zero_module(module): # Zero out the parameters of a module and return it. 
for p in module.parameters(): @@ -28,16 +39,16 @@ def zero_module(module): class MotionWrapper(nn.Module): def __init__(self, mm_name: str, mm_hash: str, mm_type: MotionModuleType): super().__init__() - is_v2 = mm_type == MotionModuleType.AnimateDiffV2 - is_sdxl = mm_type == MotionModuleType.HotShotXL - max_len = 32 if is_v2 else 24 - in_channels = (320, 640, 1280) if is_sdxl else (320, 640, 1280, 1280) + self.is_v2 = mm_type == MotionModuleType.AnimateDiffV2 + self.is_sdxl = mm_type == MotionModuleType.HotShotXL + max_len = 32 if self.is_v2 else 24 + in_channels = (320, 640, 1280) if self.is_sdxl else (320, 640, 1280, 1280) self.down_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) for c in in_channels: - self.down_blocks.append(MotionModule(c, num_mm=2, max_len=max_len, is_sdxl=is_sdxl)) - self.up_blocks.insert(0,MotionModule(c, num_mm=3, max_len=max_len, is_sdxl=is_sdxl)) - if is_v2: + self.down_blocks.append(MotionModule(c, num_mm=2, max_len=max_len, is_sdxl=self.is_sdxl)) + self.up_blocks.insert(0,MotionModule(c, num_mm=3, max_len=max_len, is_sdxl=self.is_sdxl)) + if self.is_v2: self.mid_block = MotionModule(1280, num_mm=1, max_len=max_len) self.mm_name = mm_name self.mm_type = mm_type @@ -47,7 +58,7 @@ def __init__(self, mm_name: str, mm_hash: str, mm_type: MotionModuleType): class MotionModule(nn.Module): def __init__(self, in_channels, num_mm, max_len, is_sdxl=False): super().__init__() - motion_modules = nn.ModuleList([get_motion_module(in_channels, max_len) for _ in range(num_mm)]) + motion_modules = nn.ModuleList([get_motion_module(in_channels, max_len, is_sdxl) for _ in range(num_mm)]) if is_sdxl: self.temporal_attentions = motion_modules else: @@ -55,8 +66,8 @@ def __init__(self, in_channels, num_mm, max_len, is_sdxl=False): -def get_motion_module(in_channels, max_len, is_sdxl=False): - vtm = VanillaTemporalModule(in_channels=in_channels, temporal_position_encoding_max_len=max_len) +def get_motion_module(in_channels, max_len, is_sdxl): + vtm = VanillaTemporalModule(in_channels=in_channels, temporal_position_encoding_max_len=max_len, is_sdxl=is_sdxl) return vtm.temporal_transformer if is_sdxl else vtm @@ -72,6 +83,7 @@ def __init__( temporal_position_encoding_max_len = 24, temporal_attention_dim_div = 1, zero_initialize = True, + is_sdxl = False, ): super().__init__() @@ -84,6 +96,7 @@ def __init__( cross_frame_attention_mode=cross_frame_attention_mode, temporal_position_encoding=temporal_position_encoding, temporal_position_encoding_max_len=temporal_position_encoding_max_len, + is_sdxl=is_sdxl, ) if zero_initialize: @@ -113,6 +126,7 @@ def __init__( cross_frame_attention_mode = None, temporal_position_encoding = False, temporal_position_encoding_max_len = 24, + is_sdxl = False, ): super().__init__() @@ -137,6 +151,7 @@ def __init__( cross_frame_attention_mode=cross_frame_attention_mode, temporal_position_encoding=temporal_position_encoding, temporal_position_encoding_max_len=temporal_position_encoding_max_len, + is_sdxl=is_sdxl, ) for d in range(num_layers) ] @@ -181,6 +196,7 @@ def __init__( cross_frame_attention_mode = None, temporal_position_encoding = False, temporal_position_encoding_max_len = 24, + is_sdxl = False, ): super().__init__() @@ -203,6 +219,7 @@ def __init__( cross_frame_attention_mode=cross_frame_attention_mode, temporal_position_encoding=temporal_position_encoding, temporal_position_encoding_max_len=temporal_position_encoding_max_len, + is_sdxl=is_sdxl, ) ) norms.append(nn.LayerNorm(dim)) @@ -234,7 +251,8 @@ def __init__( self, d_model, dropout 
= 0., - max_len = 24 + max_len = 24, + is_sdxl = False, ): super().__init__() self.dropout = nn.Dropout(p=dropout) @@ -243,10 +261,11 @@ def __init__( pe = torch.zeros(1, max_len, d_model) pe[0, :, 0::2] = torch.sin(position * div_term) pe[0, :, 1::2] = torch.cos(position * div_term) - self.register_buffer('pe', pe) + self.register_buffer('positional_encoding' if is_sdxl else 'pe', pe) + self.is_sdxl = is_sdxl def forward(self, x): - x = x + self.pe[:, :x.size(1)] + x = x + self.positional_encoding[:, :x.size(1)] if self.is_sdxl else self.pe[:, :x.size(1)] return self.dropout(x) @@ -505,7 +524,8 @@ def __init__( attention_mode = None, cross_frame_attention_mode = None, temporal_position_encoding = False, - temporal_position_encoding_max_len = 24, + temporal_position_encoding_max_len = 24, + is_sdxl = False, *args, **kwargs ): super().__init__(*args, **kwargs) @@ -517,7 +537,8 @@ def __init__( self.pos_encoder = PositionalEncoding( kwargs["query_dim"], dropout=0., - max_len=temporal_position_encoding_max_len + max_len=temporal_position_encoding_max_len, + is_sdxl=is_sdxl, ) if (temporal_position_encoding and attention_mode == "Temporal") else None def extra_repr(self): diff --git a/scripts/animatediff.py b/scripts/animatediff.py index a5593c02..a5efe0b5 100644 --- a/scripts/animatediff.py +++ b/scripts/animatediff.py @@ -50,7 +50,7 @@ def before_process(self, p: StableDiffusionProcessing, params: AnimateDiffProces params.set_p(p) motion_module.inject(p.sd_model, params.model) self.prompt_scheduler = AnimateDiffPromptSchedule() - self.lora_hacker = AnimateDiffLora(motion_module.mm.using_v2) + self.lora_hacker = AnimateDiffLora(motion_module.mm.is_v2) self.lora_hacker.hack() self.cfg_hacker = AnimateDiffInfV2V(p, self.prompt_scheduler) self.cfg_hacker.hack(params) diff --git a/scripts/animatediff_mm.py b/scripts/animatediff_mm.py index 2a9f857b..8e16a98b 100644 --- a/scripts/animatediff_mm.py +++ b/scripts/animatediff_mm.py @@ -4,15 +4,11 @@ import torch from einops import rearrange -# from ldm.modules.attention import SpatialTransformer -# from ldm.modules.diffusionmodules.openaimodel import (TimestepBlock, -# TimestepEmbedSequential) from ldm.modules.diffusionmodules.util import GroupNorm32 from modules import hashes, shared, sd_models from modules.devices import cpu, device, torch_gc -from motion_module import MotionWrapper -# from motion_module import VanillaTemporalModule +from motion_module import MotionWrapper, MotionModuleType from scripts.animatediff_logger import logger_animatediff as logger @@ -25,7 +21,6 @@ def __init__(self): self.prev_alpha_cumprod = None self.prev_alpha_cumprod_prev = None self.gn32_original_forward = None - # self.tes_original_forward = None def set_script_dir(self, script_dir): @@ -43,9 +38,9 @@ def _load(self, model_name): logger.info(f"Loading motion module {model_name} from {model_path}") model_hash = hashes.sha256(model_path, f"AnimateDiff/{model_name}") mm_state_dict = sd_models.read_state_dict(model_path) - using_v2 = "mid_block.motion_modules.0.temporal_transformer.proj_out.bias" in mm_state_dict.keys() - logger.warn(f"Guessed mm architecture : {'v2' if using_v2 else 'v1'}") - self.mm = MotionWrapper(model_hash, using_v2) + model_type = MotionModuleType.get_mm_type(mm_state_dict) + logger.info(f"Guessed {model_name} architecture: {model_type}") + self.mm = MotionWrapper(model_name, model_hash, model_type) missed_keys = self.mm.load_state_dict(mm_state_dict) logger.warn(f"Missing keys {missed_keys}") self.mm.to(device).eval() @@ -56,26 +51,19 @@ def 
_load(self, model_name): def inject(self, sd_model, model_name="mm_sd_v15.ckpt"): unet = sd_model.model.diffusion_model self._load(model_name) - self.gn32_original_forward = GroupNorm32.forward - gn32_original_forward = self.gn32_original_forward - # self.tes_original_forward = TimestepEmbedSequential.forward - - # def mm_tes_forward(self, x, emb, context=None): - # for layer in self: - # if isinstance(layer, TimestepBlock): - # x = layer(x, emb) - # elif isinstance(layer, (SpatialTransformer, VanillaTemporalModule)): - # x = layer(x, context) - # else: - # x = layer(x) - # return x - - # TimestepEmbedSequential.forward = mm_tes_forward - if self.mm.using_v2: - logger.info(f"Injecting motion module {model_name} into SD1.5 UNet middle block.") + inject_sdxl = sd_model.is_sdxl or self.mm.is_sdxl + sd_ver = "SDXL" if sd_model.is_sdxl else "SD1.5" + if sd_model.is_sdxl != self.mm.is_sdxl: + logger.warn(f"Motion module incompatible with SD. You are using {sd_ver} with {self.mm.mm_type}." + f"The injection and inference will go on but the result might be sub-optimal.") + + if self.mm.is_v2: + logger.info(f"Injecting motion module {model_name} into {sd_ver} UNet middle block.") unet.middle_block.insert(-1, self.mm.mid_block.motion_modules[0]) else: logger.info(f"Hacking GroupNorm32 forward function.") + self.gn32_original_forward = GroupNorm32.forward + gn32_original_forward = self.gn32_original_forward def groupnorm32_mm_forward(self, x): x = rearrange(x, "(b f) c h w -> b c f h w", b=2) @@ -85,49 +73,60 @@ def groupnorm32_mm_forward(self, x): GroupNorm32.forward = groupnorm32_mm_forward - logger.info(f"Injecting motion module {model_name} into SD1.5 UNet input blocks.") + logger.info(f"Injecting motion module {model_name} into {sd_ver} UNet input blocks.") for mm_idx, unet_idx in enumerate([1, 2, 4, 5, 7, 8, 10, 11]): + if inject_sdxl and mm_idx >= 6: + break mm_idx0, mm_idx1 = mm_idx // 2, mm_idx % 2 - unet.input_blocks[unet_idx].append( - self.mm.down_blocks[mm_idx0].motion_modules[mm_idx1] - ) + mm_inject = getattr(self.mm.down_blocks[mm_idx0], "temporal_attentions" if self.mm.is_sdxl else "motion_modules")[mm_idx1] + unet.input_blocks[unet_idx].append(mm_inject) - logger.info(f"Injecting motion module {model_name} into SD1.5 UNet output blocks.") + logger.info(f"Injecting motion module {model_name} into {sd_ver} UNet output blocks.") for unet_idx in range(12): + if inject_sdxl and unet_idx >= 9: + break mm_idx0, mm_idx1 = unet_idx // 3, unet_idx % 3 - if unet_idx % 3 == 2 and unet_idx != 11: - unet.output_blocks[unet_idx].insert( - -1, self.mm.up_blocks[mm_idx0].motion_modules[mm_idx1] - ) + mm_inject = getattr(self.mm.up_blocks[mm_idx0], "temporal_attentions" if self.mm.is_sdxl else "motion_modules")[mm_idx1] + if unet_idx % 3 == 2 and unet_idx != (8 if self.mm.is_sdxl else 11): + unet.output_blocks[unet_idx].insert(-1, mm_inject) else: - unet.output_blocks[unet_idx].append( - self.mm.up_blocks[mm_idx0].motion_modules[mm_idx1] - ) + unet.output_blocks[unet_idx].append(mm_inject) self._set_ddim_alpha(sd_model) self._set_layer_mapping(sd_model) logger.info(f"Injection finished.") + # print sd_model + with open('/home/conrevo/SD/stable-diffusion-webui/extensions/sd-webui-animatediff/model/sd_model.txt', 'w') as f: + # redirect stdout + print(sd_model, file=f) + def restore(self, sd_model): + inject_sdxl = sd_model.is_sdxl or self.mm.is_sdxl + sd_ver = "SDXL" if sd_model.is_sdxl else "SD1.5" self._restore_ddim_alpha(sd_model) unet = sd_model.model.diffusion_model - logger.info(f"Removing 
motion module from SD1.5 UNet input blocks.") + logger.info(f"Removing motion module from {sd_ver} UNet input blocks.") for unet_idx in [1, 2, 4, 5, 7, 8, 10, 11]: + if inject_sdxl and unet_idx >= 9: + break unet.input_blocks[unet_idx].pop(-1) - logger.info(f"Removing motion module from SD1.5 UNet output blocks.") + logger.info(f"Removing motion module from {sd_ver} UNet output blocks.") for unet_idx in range(12): - if unet_idx % 3 == 2 and unet_idx != 11: + if inject_sdxl and unet_idx >= 9: + break + if unet_idx % 3 == 2 and unet_idx != (8 if self.mm.is_sdxl else 11): unet.output_blocks[unet_idx].pop(-2) else: unet.output_blocks[unet_idx].pop(-1) - if self.mm.using_v2: - logger.info(f"Removing motion module from SD1.5 UNet middle block.") + if self.mm.is_v2: + logger.info(f"Removing motion module from {sd_ver} UNet middle block.") unet.middle_block.pop(-2) else: logger.info(f"Restoring GroupNorm32 forward function.") GroupNorm32.forward = self.gn32_original_forward - # TimestepEmbedSequential.forward = self.tes_original_forward + self.gn32_original_forward = None logger.info(f"Removal finished.") if shared.cmd_opts.lowvram: self.unload() @@ -140,7 +139,7 @@ def _set_ddim_alpha(self, sd_model): betas = torch.linspace( beta_start, beta_end, - sd_model.num_timesteps, + sd_model.num_idx if sd_model.is_sdxl else sd_model.num_timesteps, dtype=torch.float32, device=device, ) From 98d806d8a0a747dec43e62a49909491aa7277d6e Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Sat, 28 Oct 2023 22:19:40 -0500 Subject: [PATCH 22/54] remove print --- scripts/animatediff_mm.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/scripts/animatediff_mm.py b/scripts/animatediff_mm.py index 8e16a98b..d93fa4f8 100644 --- a/scripts/animatediff_mm.py +++ b/scripts/animatediff_mm.py @@ -95,11 +95,6 @@ def groupnorm32_mm_forward(self, x): self._set_ddim_alpha(sd_model) self._set_layer_mapping(sd_model) logger.info(f"Injection finished.") - # print sd_model - with open('/home/conrevo/SD/stable-diffusion-webui/extensions/sd-webui-animatediff/model/sd_model.txt', 'w') as f: - # redirect stdout - print(sd_model, file=f) - def restore(self, sd_model): From 0c9ddd21208ca5e2b6c04ba8eae8c9d5ba6d571f Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Sat, 28 Oct 2023 23:13:32 -0500 Subject: [PATCH 23/54] fix mm --- motion_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/motion_module.py b/motion_module.py index 442f5da0..6ac8a242 100644 --- a/motion_module.py +++ b/motion_module.py @@ -265,7 +265,7 @@ def __init__( self.is_sdxl = is_sdxl def forward(self, x): - x = x + self.positional_encoding[:, :x.size(1)] if self.is_sdxl else self.pe[:, :x.size(1)] + x = x + (self.positional_encoding[:, :x.size(1)] if self.is_sdxl else self.pe[:, :x.size(1)]) return self.dropout(x) From dafd0194e585d0a594c0b06c1ac5182ed034cb49 Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Sat, 28 Oct 2023 23:13:45 -0500 Subject: [PATCH 24/54] remove beta, prev --- scripts/animatediff_mm.py | 18 +----------------- 1 file changed, 1 insertion(+), 17 deletions(-) diff --git a/scripts/animatediff_mm.py b/scripts/animatediff_mm.py index d93fa4f8..81ea7cd5 100644 --- a/scripts/animatediff_mm.py +++ b/scripts/animatediff_mm.py @@ -17,9 +17,7 @@ class AnimateDiffMM: def __init__(self): self.mm: MotionWrapper = None self.script_dir = None - self.prev_beta = None self.prev_alpha_cumprod = None - self.prev_alpha_cumprod_prev = None self.gn32_original_forward = None @@ -134,24 +132,14 @@ def _set_ddim_alpha(self, sd_model): 
betas = torch.linspace( beta_start, beta_end, - sd_model.num_idx if sd_model.is_sdxl else sd_model.num_timesteps, + 1000 if sd_model.is_sdxl else sd_model.num_timesteps, # TODO: I'm not sure which parameter to use here for SDXL dtype=torch.float32, device=device, ) alphas = 1.0 - betas alphas_cumprod = torch.cumprod(alphas, dim=0) - alphas_cumprod_prev = torch.cat( - ( - torch.tensor([1.0], dtype=torch.float32, device=device), - alphas_cumprod[:-1], - ) - ) - self.prev_beta = sd_model.betas self.prev_alpha_cumprod = sd_model.alphas_cumprod - self.prev_alpha_cumprod_prev = sd_model.alphas_cumprod_prev - sd_model.betas = betas sd_model.alphas_cumprod = alphas_cumprod - sd_model.alphas_cumprod_prev = alphas_cumprod_prev def _set_layer_mapping(self, sd_model): @@ -162,12 +150,8 @@ def _set_layer_mapping(self, sd_model): def _restore_ddim_alpha(self, sd_model): logger.info(f"Restoring DDIM alpha.") - sd_model.betas = self.prev_beta sd_model.alphas_cumprod = self.prev_alpha_cumprod - sd_model.alphas_cumprod_prev = self.prev_alpha_cumprod_prev - self.prev_beta = None self.prev_alpha_cumprod = None - self.prev_alpha_cumprod_prev = None def unload(self): From 5670b5b5407776d661616a3983061585a52eacce Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Sat, 28 Oct 2023 23:40:45 -0500 Subject: [PATCH 25/54] basic use supported. CN and prompt travel still on the way --- scripts/animatediff_infv2v.py | 6 +++++- scripts/animatediff_mm.py | 14 ++++++++++---- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/scripts/animatediff_infv2v.py b/scripts/animatediff_infv2v.py index 8274800f..97c3b07c 100644 --- a/scripts/animatediff_infv2v.py +++ b/scripts/animatediff_infv2v.py @@ -185,7 +185,11 @@ def mm_sd_forward(self, x_in, sigma_in, cond_in, image_cond_in, make_condition_d else: _context = context mm_cn_select(_context) - out = self.inner_model(x_in[_context], sigma_in[_context], cond=make_condition_dict(cond_in[_context], image_cond_in[_context])) + out = self.inner_model( + x_in[_context], sigma_in[_context], + cond=make_condition_dict( + cond_in[_context] if not isinstance(cond_in, dict) else {k: v[_context] for k, v in cond_in.items()}, + image_cond_in[_context])) x_out = x_out.to(dtype=out.dtype) x_out[_context] = out mm_cn_restore(_context) diff --git a/scripts/animatediff_mm.py b/scripts/animatediff_mm.py index 81ea7cd5..3c18c763 100644 --- a/scripts/animatediff_mm.py +++ b/scripts/animatediff_mm.py @@ -1,10 +1,8 @@ import gc -import json import os import torch from einops import rearrange -from ldm.modules.diffusionmodules.util import GroupNorm32 from modules import hashes, shared, sd_models from modules.devices import cpu, device, torch_gc @@ -59,7 +57,11 @@ def inject(self, sd_model, model_name="mm_sd_v15.ckpt"): logger.info(f"Injecting motion module {model_name} into {sd_ver} UNet middle block.") unet.middle_block.insert(-1, self.mm.mid_block.motion_modules[0]) else: - logger.info(f"Hacking GroupNorm32 forward function.") + logger.info(f"Hacking {sd_ver} GroupNorm32 forward function.") + if self.mm.is_sdxl: + from sgm.modules.diffusionmodules.util import GroupNorm32 + else: + from ldm.modules.diffusionmodules.util import GroupNorm32 self.gn32_original_forward = GroupNorm32.forward gn32_original_forward = self.gn32_original_forward @@ -117,7 +119,11 @@ def restore(self, sd_model): logger.info(f"Removing motion module from {sd_ver} UNet middle block.") unet.middle_block.pop(-2) else: - logger.info(f"Restoring GroupNorm32 forward function.") + logger.info(f"Restoring {sd_ver} 
GroupNorm32 forward function.") + if self.mm.is_sdxl: + from sgm.modules.diffusionmodules.util import GroupNorm32 + else: + from ldm.modules.diffusionmodules.util import GroupNorm32 GroupNorm32.forward = self.gn32_original_forward self.gn32_original_forward = None logger.info(f"Removal finished.") From 13d1741400b1854e5020bc90f24deca7b879fdf2 Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Sun, 29 Oct 2023 04:52:16 -0500 Subject: [PATCH 26/54] uncomment all control models for sdxl --- scripts/animatediff_cn.py | 37 +++++++++++++++++------------------ scripts/animatediff_infv2v.py | 19 +++++++++--------- scripts/animatediff_mm.py | 3 ++- 3 files changed, 29 insertions(+), 30 deletions(-) diff --git a/scripts/animatediff_cn.py b/scripts/animatediff_cn.py index 3b6e3fbd..7ad2f49f 100644 --- a/scripts/animatediff_cn.py +++ b/scripts/animatediff_cn.py @@ -136,10 +136,10 @@ def hack_cn(self): def hacked_main_entry(self, p: StableDiffusionProcessing): from scripts import external_code, global_state, hook - # from scripts.controlnet_lora import bind_control_lora # do not support control lora for sdxl + from scripts.controlnet_lora import bind_control_lora from scripts.adapter import Adapter, Adapter_light, StyleAdapter from scripts.batch_hijack import InputMode - # from scripts.controlnet_lllite import PlugableControlLLLite, clear_all_lllite # do not support controlllite for sdxl + from scripts.controlnet_lllite import PlugableControlLLLite, clear_all_lllite from scripts.controlmodel_ipadapter import (PlugableIPAdapter, clear_all_ip_adapter) from scripts.hook import ControlModelType, ControlParams, UnetHook @@ -196,14 +196,14 @@ def set_numpy_seed(p: processing.StableDiffusionProcessing) -> Optional[int]: unet = sd_ldm.model.diffusion_model self.noise_modifier = None - # setattr(p, 'controlnet_control_loras', []) # do not support control lora for sdxl + setattr(p, 'controlnet_control_loras', []) if self.latest_network is not None: # always restore (~0.05s) self.latest_network.restore() # always clear (~0.05s) - # clear_all_lllite() # do not support controlllite for sdxl + clear_all_lllite() clear_all_ip_adapter() self.enabled_units = cn_script.get_enabled_units(p) @@ -242,10 +242,10 @@ def set_numpy_seed(p: processing.StableDiffusionProcessing) -> Optional[int]: model_net = cn_script.load_control_model(p, unet, unit.model) model_net.reset() - # if getattr(model_net, 'is_control_lora', False): # do not support control lora for sdxl - # control_lora = model_net.control_model - # bind_control_lora(unet, control_lora) - # p.controlnet_control_loras.append(control_lora) + if getattr(model_net, 'is_control_lora', False): + control_lora = model_net.control_model + bind_control_lora(unet, control_lora) + p.controlnet_control_loras.append(control_lora) if getattr(unit, 'input_mode', InputMode.SIMPLE) == InputMode.BATCH: input_images = [] @@ -352,8 +352,8 @@ def set_numpy_seed(p: processing.StableDiffusionProcessing) -> Optional[int]: control_model_type = ControlModelType.T2I_StyleAdapter elif isinstance(model_net, PlugableIPAdapter): control_model_type = ControlModelType.IPAdapter - # elif isinstance(model_net, PlugableControlLLLite): # do not support controlllite for sdxl - # control_model_type = ControlModelType.Controlllite + elif isinstance(model_net, PlugableControlLLLite): + control_model_type = ControlModelType.Controlllite if control_model_type is ControlModelType.ControlNet: global_average_pooling = model_net.control_model.global_average_pooling @@ -577,15 +577,14 @@ def 
recolor_intensity_post_processing(x, i): start=param.start_guidance_percent, end=param.stop_guidance_percent ) - # Do not support controlllite for sdxl - # if param.control_model_type == ControlModelType.Controlllite: - # param.control_model.hook( - # model=unet, - # cond=param.hint_cond, - # weight=param.weight, - # start=param.start_guidance_percent, - # end=param.stop_guidance_percent - # ) + if param.control_model_type == ControlModelType.Controlllite: + param.control_model.hook( + model=unet, + cond=param.hint_cond, + weight=param.weight, + start=param.start_guidance_percent, + end=param.stop_guidance_percent + ) self.detected_map = detected_maps self.post_processors = post_processors diff --git a/scripts/animatediff_infv2v.py b/scripts/animatediff_infv2v.py index 97c3b07c..8dc5532d 100644 --- a/scripts/animatediff_infv2v.py +++ b/scripts/animatediff_infv2v.py @@ -100,7 +100,6 @@ def hack(self, params: AnimateDiffProcess): def mm_cn_select(context: List[int]): # take control images for current context. - # controlllite is for sdxl and we do not support it. reserve here for future use is needed. if cn_script is not None and cn_script.latest_network is not None: from scripts.hook import ControlModelType for control in cn_script.latest_network.control_params: @@ -132,11 +131,11 @@ def mm_cn_select(context: List[int]): control.control_model.image_emb = control.control_model.image_emb[context] control.control_model.uncond_image_emb_backup = control.control_model.uncond_image_emb control.control_model.uncond_image_emb = control.control_model.uncond_image_emb[context] - # if control.control_model_type == ControlModelType.Controlllite: - # for module in control.control_model.modules.values(): - # if module.cond_image.shape[0] > len(context): - # module.cond_image_backup = module.cond_image - # module.set_cond_image(module.cond_image[context]) + if control.control_model_type == ControlModelType.Controlllite: + for module in control.control_model.modules.values(): + if module.cond_image.shape[0] > len(context): + module.cond_image_backup = module.cond_image + module.set_cond_image(module.cond_image[context]) def mm_cn_restore(context: List[int]): # restore control images for next context @@ -172,10 +171,10 @@ def mm_cn_restore(context: List[int]): # control.control_model.uncond_image_emb_backup[context] = control.control_model.uncond_image_emb control.control_model.image_emb = control.control_model.image_emb_backup control.control_model.uncond_image_emb = control.control_model.uncond_image_emb_backup - # if control.control_model_type == ControlModelType.Controlllite: - # for module in control.control_model.modules.values(): - # if module.cond_image.shape[0] > len(context): - # module.set_cond_image(module.cond_image_backup) + if control.control_model_type == ControlModelType.Controlllite: + for module in control.control_model.modules.values(): + if module.cond_image.shape[0] > len(context): + module.set_cond_image(module.cond_image_backup) def mm_sd_forward(self, x_in, sigma_in, cond_in, image_cond_in, make_condition_dict): x_out = torch.zeros_like(x_in) diff --git a/scripts/animatediff_mm.py b/scripts/animatediff_mm.py index 3c18c763..39d24b6d 100644 --- a/scripts/animatediff_mm.py +++ b/scripts/animatediff_mm.py @@ -138,7 +138,8 @@ def _set_ddim_alpha(self, sd_model): betas = torch.linspace( beta_start, beta_end, - 1000 if sd_model.is_sdxl else sd_model.num_timesteps, # TODO: I'm not sure which parameter to use here for SDXL + # TODO: I'm not sure which parameter to use here for SDXL + 1000 
if sd_model.is_sdxl else sd_model.num_timesteps, dtype=torch.float32, device=device, ) From aaca3ebd9c7a0f8a8bee90ad66679e77585927e9 Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Sun, 29 Oct 2023 05:42:54 -0500 Subject: [PATCH 27/54] prompt travel --- scripts/animatediff_mm.py | 4 ++-- scripts/animatediff_prompt.py | 25 +++++++++++++++++++------ 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/scripts/animatediff_mm.py b/scripts/animatediff_mm.py index 39d24b6d..08e26158 100644 --- a/scripts/animatediff_mm.py +++ b/scripts/animatediff_mm.py @@ -50,8 +50,8 @@ def inject(self, sd_model, model_name="mm_sd_v15.ckpt"): inject_sdxl = sd_model.is_sdxl or self.mm.is_sdxl sd_ver = "SDXL" if sd_model.is_sdxl else "SD1.5" if sd_model.is_sdxl != self.mm.is_sdxl: - logger.warn(f"Motion module incompatible with SD. You are using {sd_ver} with {self.mm.mm_type}." - f"The injection and inference will go on but the result might be sub-optimal.") + logger.warn(f"Motion module incompatible with SD. You are using {sd_ver} with {self.mm.mm_type}. " + f"You will see an error afterwards. Even if the injection and inference seem to go on, you will get bad results.") if self.mm.is_v2: logger.info(f"Injecting motion module {model_name} into {sd_ver} UNet middle block.") diff --git a/scripts/animatediff_prompt.py b/scripts/animatediff_prompt.py index bfa51960..35da88a4 100644 --- a/scripts/animatediff_prompt.py +++ b/scripts/animatediff_prompt.py @@ -97,20 +97,33 @@ def single_cond(self, center_frame, video_length: int, cond: torch.Tensor, close dist_next += video_length if key_prev == key_next or dist_prev + dist_next == 0: - return cond[key_prev] + return cond[key_prev] if isinstance(cond, torch.Tensor) else {k: v[key_prev] for k, v in cond.items()} rate = dist_prev / (dist_prev + dist_next) - - return AnimateDiffPromptSchedule.slerp(cond[key_prev], cond[key_next], rate) + if isinstance(cond, torch.Tensor): + return AnimateDiffPromptSchedule.slerp(cond[key_prev], cond[key_next], rate) + else: # isinstance(cond, dict) + return { + k: AnimateDiffPromptSchedule.slerp(v[key_prev], v[key_next], rate) + for k, v in cond.items() + } def multi_cond(self, cond: torch.Tensor, closed_loop = False): if self.prompt_map is None: return cond - cond_list = [] + cond_list = [] if isinstance(cond, torch.Tensor) else {k: [] for k in cond.keys()} for i in range(cond.shape[0]): - cond_list.append(self.single_cond(i, cond.shape[0], cond, closed_loop)) - return torch.stack(cond_list).to(cond.dtype).to(cond.device) + single_cond = self.single_cond(i, cond.shape[0], cond, closed_loop) + if isinstance(cond, torch.Tensor): + cond_list.append(single_cond) + else: + for k, v in single_cond.items(): + cond_list[k].append(v) + if isinstance(cond, torch.Tensor): + return torch.stack(cond_list).to(cond.dtype).to(cond.device) + else: + return {k: torch.stack(v).to(cond[k].dtype).to(cond[k].device) for k, v in cond_list.items()} @staticmethod From 46b61f2d00af567be4273a4a0c98326a783a387a Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Sun, 29 Oct 2023 06:10:54 -0500 Subject: [PATCH 28/54] readme --- README.md | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 846e8405..f086eae8 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,7 @@ You might also be interested in another extension I created: [Segment Anything f - [Motion LoRA](#motion-lora) - [Prompt Travel](#prompt-travel) - [ControlNet V2V](#controlnet-v2v) +- [HotShot-XL](#hotshot-xl) - [Model 
Zoo](#model-zoo)
- [VRAM](#vram)
- [Batch Size](#batch-size)
@@ -51,9 +52,11 @@ You might also be interested in another extension I created: [Segment Anything f
- `2023/10/19`: [v1.9.3](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.9.3): Support webp output format. See [#233](https://github.com/continue-revolution/sd-webui-animatediff/pull/233) for more information.
- `2023/10/21`: [v1.9.4](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.9.4): Save prompt travel to output images, `Reverse` merged to `Closed loop` (See [WebUI Parameters](#webui-parameters)), remove `TimestepEmbedSequential` hijack, remove `hints.js`, better explanation of several context-related parameters.
- `2023/10/25`: [v1.10.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.10.0): Support img2img batch. You need ControlNet installed to make it work properly (you do not need to enable ControlNet). See [ControlNet V2V](#controlnet-v2v) for more information.
+- `2023/10/29`: [v1.11.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.11.0): Support [HotShot-XL](https://github.com/hotshotco/Hotshot-XL) for SDXL. See [HotShot-XL](#hotshot-xl) for more information.
For future update plan, please query [here](https://github.com/continue-revolution/sd-webui-animatediff/pull/224).
+
## How to Use
1. Update your WebUI to v1.6.0 and ControlNet to v1.1.410, then install this extension via link. I do not plan to support older versions.
1. Download motion modules and put the model weights under `stable-diffusion-webui/extensions/sd-webui-animatediff/model/`. If you want to use another directory to save model weights, please go to `Settings/AnimateDiff`. See [model zoo](#model-zoo) for a list of available motion modules.
@@ -111,7 +114,7 @@ Just like how you use ControlNet. Here is a sample. Due to the limitation of Web
If you enter something smaller than your `Context batch size` other than 0: you will get the first `Number of frames` frames as your output GIF from your whole generation. All following frames will not appear in your generated GIF, but will be saved as PNGs as usual. Do not set `Number of frames` to be something smaller than `Context batch size` other than 0 because of [#213](https://github.com/continue-revolution/sd-webui-animatediff/issues/213).
1. **FPS** — Frames per second, which is how many frames (images) are shown every second. If 16 frames are generated at 8 frames per second, your GIF’s duration is 2 seconds. If you submit a source video, your FPS will be the same as the source video.
1. **Display loop number** — How many times the GIF is played. A value of `0` means the GIF never stops playing.
-1. **Context batch size** — How many frames will be passed into the motion module at once. The model is trained with 16 frames, so it'll give the best results when the number of frames is set to `16`. Choose [1, 24] for V1 motion modules and [1, 32] for V2 motion modules.
+1. **Context batch size** — How many frames will be passed into the motion module at once. The SD1.5 motion modules are trained with 16 frames, so they'll give the best results when the number of frames is set to `16`. SDXL HotShotXL motion modules are trained with 8 frames instead. Choose [1, 24] for V1 / HotShotXL motion modules and [1, 32] for V2 motion modules.
1. **Closed loop** — Closed loop means that this extension will try to make the last frame the same as the first frame.
   1. When `Number of frames` > `Context batch size`, including when ControlNet is enabled and the source video frame number > `Context batch size` and `Number of frames` is 0, closed loop will be performed by the AnimateDiff infinite context generator.
   1. When `Number of frames` <= `Context batch size`, the AnimateDiff infinite context generator will not be effective. Only when you choose `A` will AnimateDiff append a reversed list of frames to the original list of frames to form a closed loop.
@@ -190,21 +193,42 @@ For people who want to inpaint videos: enter a folder which contains two sub-fol
AnimateDiff in img2img batch will be available in [v1.10.0](https://github.com/continue-revolution/sd-webui-animatediff/pull/224).
+## HotShot-XL
+
+[HotShot-XL](https://github.com/hotshotco/Hotshot-XL) has an architecture identical to AnimateDiff. The only two differences are:
+- HotShot-XL is trained with 8 frames instead of 16. It is recommended that you set `Context batch size` to 8 for HotShot-XL.
+- HotShot-XL has fewer layers because it is built for SDXL.
+
+Although HotShot-XL has an identical structure to AnimateDiff, I strongly discourage you from using AnimateDiff for SDXL, or HotShot-XL for SD1.5 - you will get severe artifacts if you do that. I have decided not to support that, even though it would not be hard for me to do.
+
+Technically, all features available for AnimateDiff are also available for HotShot-XL. However, I have not tested all of them. I have tested infinite context generation and prompt travel; I have not tested ControlNet. If you find any bug, please report it to me.
+
+The difference between this extension and the official [HotShot-XL extension](https://github.com/hotshotco/Hotshot-XL-Automatic1111) is that you can completely get rid of diffusers if you use this one.
+
+For VRAM usage, please read [VRAM](#vram).
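To make the relationship between `Number of frames` and `Context batch size` more concrete, here is a rough sketch of how frame indices can be grouped into overlapping windows before they are passed to the motion module. It only illustrates the idea behind the infinite context generator described above; the function name, the overlap value, and the exact window layout are made up for this example and are not the scheduler the extension actually ships.

```python
from typing import List

def sliding_context_windows(num_frames: int, context_size: int, overlap: int = 4) -> List[List[int]]:
    """Illustrative only: split frame indices into overlapping windows of at most `context_size`."""
    if num_frames <= context_size:
        return [list(range(num_frames))]
    windows, start = [], 0
    step = max(context_size - overlap, 1)
    while start < num_frames:
        end = min(start + context_size, num_frames)
        windows.append(list(range(start, end)))
        if end == num_frames:
            break
        start += step
    return windows

# 32 frames with a context batch size of 16 and an overlap of 4:
# [[0, ..., 15], [12, ..., 27], [24, ..., 31]]
print(sliding_context_windows(32, 16))
```

Each window is denoised with the motion module, and overlapping frames are blended between windows, which is why a video longer than the context batch size can still stay temporally consistent.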
+
+
## Model Zoo
-- `mm_sd_v14.ckpt` & `mm_sd_v15.ckpt` & `mm_sd_v15_v2.ckpt` by [@guoyww](https://github.com/guoyww): [Google Drive](https://drive.google.com/drive/folders/1EqLC65eR1-W-sGD0Im7fkED6c8GkiNFI) | [HuggingFace](https://huggingface.co/guoyww/animatediff) | [CivitAI](https://civitai.com/models/108836) | [Baidu NetDisk](https://pan.baidu.com/s/18ZpcSM6poBqxWNHtnyMcxg?pwd=et8y)
+- `mm_sd_v14.ckpt` & `mm_sd_v15.ckpt` & `mm_sd_v15_v2.ckpt` by [@guoyww](https://github.com/guoyww): [Google Drive](https://drive.google.com/drive/folders/1EqLC65eR1-W-sGD0Im7fkED6c8GkiNFI) | [HuggingFace](https://huggingface.co/guoyww/animatediff/tree/main) | [CivitAI](https://civitai.com/models/108836)
- `mm_sd_v14.safetensors` & `mm_sd_v15.safetensors` & `mm_sd_v15_v2.safetensors` by [@neph1](https://github.com/neph1): [HuggingFace](https://huggingface.co/guoyww/animatediff/tree/refs%2Fpr%2F3)
+- `mm_sd_v14.fp16.safetensors` & `mm_sd_v15.fp16.safetensors` & `mm_sd_v15_v2.fp16.safetensors` by [@neggles](https://huggingface.co/neggles/): [HuggingFace](https://huggingface.co/neggles/)
- `mm-Stabilized_high.pth` & `mm-Stabbilized_mid.pth` by [@manshoety](https://huggingface.co/manshoety): [HuggingFace](https://huggingface.co/manshoety/AD_Stabilized_Motion/tree/main)
- `temporaldiff-v1-animatediff.ckpt` by [@CiaraRowles](https://huggingface.co/CiaraRowles): [HuggingFace](https://huggingface.co/CiaraRowles/TemporalDiff/tree/main)
+- `hsxl_temporal_layers.safetensors` & `hsxl_temporal_layers.f16.safetensors` by [@hotshotco](https://huggingface.co/hotshotco/): [HuggingFace](https://huggingface.co/hotshotco/Hotshot-XL/tree/main)

## VRAM
-Actual VRAM usage depends on your image size and context batch size. You can try to reduce image size or context batch size to reduce VRAM usage. I list some data tested on Ubuntu 22.04, NVIDIA 4090, torch 2.0.1+cu117, H=W=512, frame=16 (default setting) below. `w/`/`w/o` means `Batch cond/uncond` in `Settings/Optimization` is checked/unchecked.
+Actual VRAM usage depends on your image size and context batch size. You can try to reduce image size or context batch size to reduce VRAM usage.
+
+The following data are for SD1.5 + AnimateDiff, tested on Ubuntu 22.04, NVIDIA 4090, torch 2.0.1+cu117, H=W=512, frame=16 (default setting). `w/`/`w/o` means `Batch cond/uncond` in `Settings/Optimization` is checked/unchecked.

| Optimization | VRAM w/ | VRAM w/o |
| --- | --- | --- |
| No optimization | 12.13GB | |
| xformers/sdp | 5.60GB | 4.21GB |
| sub-quadratic | 10.39GB | |

+For SDXL + HotShot + SDP, tested on Ubuntu 22.04, NVIDIA 4090, torch 2.0.1+cu117, H=W=512, frame=8 (default setting), you need 8.66GB VRAM.
+
## Batch Size
Batch size on WebUI will be replaced by the GIF frame number internally: 1 full GIF is generated in 1 batch. If you want to generate multiple GIFs at once, please change the batch number.
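Patches 31, 32, and 40 later in this series change how results come back through the WebUI API: generated videos are returned as base64-encoded strings and the AnimateDiff arguments may be passed as a plain dict. Below is a minimal sketch of such a call. It assumes a local WebUI listening on 127.0.0.1:7860; the `alwayson_scripts` key and the argument names follow the `AnimateDiffProcess` fields visible in these patches, so treat this as an illustration rather than a guaranteed API contract.

```python
import base64
import requests

payload = {
    "prompt": "a cat walking in a garden",
    "steps": 20,
    "alwayson_scripts": {
        "AnimateDiff": {
            # a single dict of AnimateDiffProcess fields is accepted when calling through the API
            "args": [{"enable": True, "video_length": 16, "fps": 8, "format": ["GIF"]}]
        }
    },
}
resp = requests.post("http://127.0.0.1:7860/sdapi/v1/txt2img", json=payload).json()

# with these patches applied, video files come back as base64-encoded strings in `images`
for i, item in enumerate(resp["images"]):
    with open(f"animatediff_{i}.gif", "wb") as f:
        f.write(base64.b64decode(item))
```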
From d3f3f64f0c726b2c97c2547441b2c63cc1d99b85 Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Sun, 5 Nov 2023 00:52:08 -0500 Subject: [PATCH 29/54] better naming convention --- scripts/animatediff_output.py | 74 ++++++++++++++++++----------------- 1 file changed, 39 insertions(+), 35 deletions(-) diff --git a/scripts/animatediff_output.py b/scripts/animatediff_output.py index af0a3ac7..5f9be166 100644 --- a/scripts/animatediff_output.py +++ b/scripts/animatediff_output.py @@ -24,38 +24,43 @@ def output( for i in range(res.index_of_first_image, len(res.images), step): # frame interpolation replaces video_list with interpolated frames # so make a copy instead of a slice (reference), to avoid modifying res - video_list = [image.copy() for image in res.images[i : i + params.video_length]] + frame_list = [image.copy() for image in res.images[i : i + params.video_length]] seq = images.get_next_sequence_number(f"{p.outpath_samples}/AnimateDiff", "") filename = f"{seq:05}-{res.all_seeds[(i-res.index_of_first_image)]}" video_path_prefix = f"{p.outpath_samples}/AnimateDiff/{filename}" - video_list = self._add_reverse(params, video_list) - video_list = self._interp(p, params, video_list, filename) - video_paths += self._save(params, video_list, video_path_prefix, res, i) + frame_list = self._add_reverse(params, frame_list) + frame_list = self._interp(p, params, frame_list, filename) + video_paths += self._save(params, frame_list, video_path_prefix, res, i) if len(video_paths) > 0: - res.images = video_list if p.is_api else video_paths + if p.is_api: + res.images = frame_list + else: + res.images = video_paths + - def _add_reverse(self, params: AnimateDiffProcess, video_list: list): + def _add_reverse(self, params: AnimateDiffProcess, frame_list: list): if params.video_length <= params.batch_size and params.closed_loop in ['A']: - video_list_reverse = video_list[::-1] - if len(video_list_reverse) > 0: - video_list_reverse.pop(0) - if len(video_list_reverse) > 0: - video_list_reverse.pop(-1) - return video_list + video_list_reverse - return video_list + frame_list_reverse = frame_list[::-1] + if len(frame_list_reverse) > 0: + frame_list_reverse.pop(0) + if len(frame_list_reverse) > 0: + frame_list_reverse.pop(-1) + return frame_list + frame_list_reverse + return frame_list + def _interp( self, p: StableDiffusionProcessing, params: AnimateDiffProcess, - video_list: list, + frame_list: list, filename: str ): if params.interp not in ['FILM']: - return video_list + return frame_list try: from deforum_helpers.frame_interpolation import ( @@ -63,7 +68,7 @@ def _interp( from film_interpolation.film_inference import run_film_interp_infer except ImportError: logger.error("Deforum not found. 
Please install: https://github.com/deforum-art/deforum-for-automatic1111-webui.git") - return video_list + return frame_list import glob import os @@ -78,13 +83,13 @@ def _interp( film_model_path = os.path.join(film_model_folder, film_model_name) check_and_download_film_model('film_net_fp16.pt', film_model_folder) - film_in_between_frames_count = calculate_frames_to_add(len(video_list), params.interp_x) + film_in_between_frames_count = calculate_frames_to_add(len(frame_list), params.interp_x) # save original frames to tmp folder for deforum input tmp_folder = f"{p.outpath_samples}/AnimateDiff/tmp" input_folder = f"{tmp_folder}/input" os.makedirs(input_folder, exist_ok=True) - for tmp_seq, frame in enumerate(video_list): + for tmp_seq, frame in enumerate(frame_list): imageio.imwrite(f"{input_folder}/{tmp_seq:05}.png", frame) # deforum saves output frames to tmp/{filename} @@ -99,11 +104,11 @@ def _interp( # load deforum output frames and replace video_list interp_frame_paths = sorted(glob.glob(os.path.join(save_folder, '*.png'))) - video_list = [] + frame_list = [] for f in interp_frame_paths: with Image.open(f) as img: img.load() - video_list.append(img) + frame_list.append(img) # if saving PNG, also save interpolated frames if "PNG" in params.format: @@ -115,23 +120,24 @@ def _interp( try: shutil.rmtree(tmp_folder) except OSError as e: print(f"Error: {e}") - return video_list + return frame_list + def _save( self, params: AnimateDiffProcess, - video_list: list, + frame_list: list, video_path_prefix: str, res: Processed, index: int, ): video_paths = [] - video_array = [np.array(v) for v in video_list] + video_array = [np.array(v) for v in frame_list] infotext = res.infotexts[index] use_infotext = shared.opts.enable_pnginfo and infotext is not None if "PNG" in params.format and shared.opts.data.get("animatediff_save_to_custom", False): Path(video_path_prefix).mkdir(exist_ok=True, parents=True) - for i, frame in enumerate(video_list): + for i, frame in enumerate(frame_list): png_filename = f"{video_path_prefix}/{i:05}.png" png_info = PngImagePlugin.PngInfo() png_info.add_text('parameters', res.infotexts[0]) @@ -157,7 +163,7 @@ def _save( "split": ("split", ""), "palgen": ("palettegen", ""), "paluse": ("paletteuse", ""), - "scale": ("scale", f"{video_list[0].width}:{video_list[0].height}") + "scale": ("scale", f"{frame_list[0].width}:{frame_list[0].height}") }, [ ("video_in", "scale", 0, 0), @@ -201,6 +207,7 @@ def _save( ) if shared.opts.data.get("animatediff_optimize_gif_gifsicle", False): self._optimize_gif(video_path_gif) + if "MP4" in params.format: video_path_mp4 = video_path_prefix + ".mp4" video_paths.append(video_path_mp4) @@ -213,9 +220,12 @@ def _save( "sd-webui-animatediff save mp4 requirement: imageio[ffmpeg]", ) imageio.imwrite(video_path_mp4, video_array, fps=params.fps, codec="h264") + if "TXT" in params.format and res.images[index].info is not None: video_path_txt = video_path_prefix + ".txt" - self._save_txt(video_path_txt, infotext) + with open(video_path_txt, "w", encoding="utf8") as file: + file.write(f"{infotext}\n") + if "WEBP" in params.format: if PIL.features.check('webp_anim'): video_path_webp = video_path_prefix + ".webp" @@ -236,8 +246,10 @@ def _save( # see additional Pillow WebP options at https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html#webp else: logger.warn("WebP animation in Pillow requires system WebP library v0.5.0 or later") + return video_paths + def _optimize_gif(self, video_path: str): try: import pygifsicle @@ -255,18 +267,10 @@ 
def _optimize_gif(self, video_path: str): except FileNotFoundError: logger.warn("gifsicle not found, required for optimized GIFs, try: apt install gifsicle") - def _save_txt( - self, - video_path: str, - info: str, - ): - with open(video_path, "w", encoding="utf8") as file: - file.write(f"{info}\n") def _encode_video_to_b64(self, paths): videos = [] for v_path in paths: with open(v_path, "rb") as video_file: - encoded_video = base64.b64encode(video_file.read()) - videos.append(encoded_video.decode("utf-8")) + videos.append(base64.b64encode(video_file.read())) return videos From 287f9feb6411cf7f63d15488da31b1c85f7324ae Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Sun, 5 Nov 2023 01:19:14 -0500 Subject: [PATCH 30/54] add date --- scripts/animatediff_output.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/scripts/animatediff_output.py b/scripts/animatediff_output.py index 5f9be166..4e28a193 100644 --- a/scripts/animatediff_output.py +++ b/scripts/animatediff_output.py @@ -19,16 +19,18 @@ def output( ): video_paths = [] logger.info("Merging images into GIF.") - Path(f"{p.outpath_samples}/AnimateDiff").mkdir(exist_ok=True, parents=True) + filename_generator = images.FilenameGenerator(p, p.seed, p.prompt, None) + output_dir = Path(f"{p.outpath_samples}/AnimateDiff/{filename_generator.datetime()}") + output_dir.mkdir(parents=True, exist_ok=True) step = params.video_length if params.video_length > params.batch_size else params.batch_size for i in range(res.index_of_first_image, len(res.images), step): # frame interpolation replaces video_list with interpolated frames # so make a copy instead of a slice (reference), to avoid modifying res frame_list = [image.copy() for image in res.images[i : i + params.video_length]] - seq = images.get_next_sequence_number(f"{p.outpath_samples}/AnimateDiff", "") + seq = images.get_next_sequence_number(output_dir, "") filename = f"{seq:05}-{res.all_seeds[(i-res.index_of_first_image)]}" - video_path_prefix = f"{p.outpath_samples}/AnimateDiff/{filename}" + video_path_prefix = output_dir / filename frame_list = self._add_reverse(params, frame_list) frame_list = self._interp(p, params, frame_list, filename) @@ -127,7 +129,7 @@ def _save( self, params: AnimateDiffProcess, frame_list: list, - video_path_prefix: str, + video_path_prefix: Path, res: Processed, index: int, ): @@ -136,15 +138,15 @@ def _save( infotext = res.infotexts[index] use_infotext = shared.opts.enable_pnginfo and infotext is not None if "PNG" in params.format and shared.opts.data.get("animatediff_save_to_custom", False): - Path(video_path_prefix).mkdir(exist_ok=True, parents=True) + video_path_prefix.mkdir(exist_ok=True, parents=True) for i, frame in enumerate(frame_list): - png_filename = f"{video_path_prefix}/{i:05}.png" + png_filename = video_path_prefix/f"{i:05}.png" png_info = PngImagePlugin.PngInfo() png_info.add_text('parameters', res.infotexts[0]) imageio.imwrite(png_filename, frame, pnginfo=png_info) if "GIF" in params.format: - video_path_gif = video_path_prefix + ".gif" + video_path_gif = str(video_path_prefix) + ".gif" video_paths.append(video_path_gif) if shared.opts.data.get("animatediff_optimize_gif_palette", False): try: @@ -209,7 +211,7 @@ def _save( self._optimize_gif(video_path_gif) if "MP4" in params.format: - video_path_mp4 = video_path_prefix + ".mp4" + video_path_mp4 = str(video_path_prefix) + ".mp4" video_paths.append(video_path_mp4) try: imageio.imwrite(video_path_mp4, video_array, fps=params.fps, codec="h264") @@ -222,13 
+224,13 @@ def _save( imageio.imwrite(video_path_mp4, video_array, fps=params.fps, codec="h264") if "TXT" in params.format and res.images[index].info is not None: - video_path_txt = video_path_prefix + ".txt" + video_path_txt = str(video_path_prefix) + ".txt" with open(video_path_txt, "w", encoding="utf8") as file: file.write(f"{infotext}\n") if "WEBP" in params.format: if PIL.features.check('webp_anim'): - video_path_webp = video_path_prefix + ".webp" + video_path_webp = str(video_path_prefix) + ".webp" video_paths.append(video_path_webp) exif_bytes = b'' if use_infotext: @@ -272,5 +274,5 @@ def _encode_video_to_b64(self, paths): videos = [] for v_path in paths: with open(v_path, "rb") as video_file: - videos.append(base64.b64encode(video_file.read())) + videos.append(base64.b64encode(video_file.read()).decode("utf-8")) return videos From c9fc0c6d9348b899177c1bfd0f2cfe17a9eddaee Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Sun, 5 Nov 2023 01:41:18 -0500 Subject: [PATCH 31/54] api return b64 video --- scripts/animatediff_output.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/scripts/animatediff_output.py b/scripts/animatediff_output.py index 4e28a193..58c1236a 100644 --- a/scripts/animatediff_output.py +++ b/scripts/animatediff_output.py @@ -37,10 +37,7 @@ def output( video_paths += self._save(params, frame_list, video_path_prefix, res, i) if len(video_paths) > 0: - if p.is_api: - res.images = frame_list - else: - res.images = video_paths + res.images = (self._encode_video_to_b64(video_paths) + (frame_list if 'Frame' in params.format else [])) if p.is_api else video_paths def _add_reverse(self, params: AnimateDiffProcess, frame_list: list): @@ -112,11 +109,9 @@ def _interp( img.load() frame_list.append(img) - # if saving PNG, also save interpolated frames + # if saving PNG, enforce saving to custom folder if "PNG" in params.format: - save_interp_path = f"{p.outpath_samples}/AnimateDiff/interp" - os.makedirs(save_interp_path, exist_ok=True) - shutil.move(save_folder, save_interp_path) + params.force_save_to_custom = True # remove tmp folder try: shutil.rmtree(tmp_folder) @@ -137,7 +132,7 @@ def _save( video_array = [np.array(v) for v in frame_list] infotext = res.infotexts[index] use_infotext = shared.opts.enable_pnginfo and infotext is not None - if "PNG" in params.format and shared.opts.data.get("animatediff_save_to_custom", False): + if "PNG" in params.format and (shared.opts.data.get("animatediff_save_to_custom", False) or getattr(params, "force_save_to_custom", False)): video_path_prefix.mkdir(exist_ok=True, parents=True) for i, frame in enumerate(frame_list): png_filename = video_path_prefix/f"{i:05}.png" From 08d30b1a3c4100456016b18864be2336a877a34f Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Sun, 5 Nov 2023 01:42:41 -0600 Subject: [PATCH 32/54] hook encode_pil_to_base64 --- README.md | 2 +- scripts/animatediff_output.py | 21 +++++++++++++++++---- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 2381dcf2..be74b2d4 100644 --- a/README.md +++ b/README.md @@ -275,7 +275,7 @@ I thank researchers from [Shanghai AI Lab](https://www.shlab.org.cn/), especiall I also thank community developers, especially - [@zappityzap](https://github.com/zappityzap) who developed the majority of the [output features](https://github.com/continue-revolution/sd-webui-animatediff/blob/master/scripts/animatediff_output.py) -- [@TDS4874](https://github.com/TDS4874) and [@opparco](https://github.com/opparco) for resolving 
the grey issue which significantly improve the performance of this extension +- [@TDS4874](https://github.com/TDS4874) and [@opparco](https://github.com/opparco) for resolving the grey issue which significantly improve the performance - [@talesofai](https://github.com/talesofai) who developed i2v in [this forked repo](https://github.com/talesofai/AnimateDiff) - [@rkfg](https://github.com/rkfg) for developing GIF palette optimization diff --git a/scripts/animatediff_output.py b/scripts/animatediff_output.py index 58c1236a..b3eb07df 100644 --- a/scripts/animatediff_output.py +++ b/scripts/animatediff_output.py @@ -14,9 +14,10 @@ class AnimateDiffOutput: - def output( - self, p: StableDiffusionProcessing, res: Processed, params: AnimateDiffProcess - ): + api_encode_pil_to_base64_hooked = False + + + def output(self, p: StableDiffusionProcessing, res: Processed, params: AnimateDiffProcess): video_paths = [] logger.info("Merging images into GIF.") filename_generator = images.FilenameGenerator(p, p.seed, p.prompt, None) @@ -37,7 +38,19 @@ def output( video_paths += self._save(params, frame_list, video_path_prefix, res, i) if len(video_paths) > 0: - res.images = (self._encode_video_to_b64(video_paths) + (frame_list if 'Frame' in params.format else [])) if p.is_api else video_paths + if p.is_api: + if not AnimateDiffOutput.api_encode_pil_to_base64_hooked: + AnimateDiffOutput.api_encode_pil_to_base64_hooked = True + from modules.api import api + api_encode_pil_to_base64 = api.encode_pil_to_base64 + def hooked_encode_pil_to_base64(image): + if isinstance(image, str): + return image + return api_encode_pil_to_base64(image) + api.encode_pil_to_base64 = hooked_encode_pil_to_base64 + res.images = self._encode_video_to_b64(video_paths) + (frame_list if 'Frame' in params.format else []) + else: + res.images = video_paths def _add_reverse(self, params: AnimateDiffProcess, frame_list: list): From 1ad49bb6a5e471aa3908267bdfcc871dc313dd2e Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Sun, 5 Nov 2023 02:31:28 -0600 Subject: [PATCH 33/54] use date instead of datetime --- scripts/animatediff_output.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/scripts/animatediff_output.py b/scripts/animatediff_output.py index b3eb07df..542dfdf6 100644 --- a/scripts/animatediff_output.py +++ b/scripts/animatediff_output.py @@ -1,4 +1,5 @@ import base64 +import datetime from pathlib import Path import imageio.v3 as imageio @@ -20,8 +21,8 @@ class AnimateDiffOutput: def output(self, p: StableDiffusionProcessing, res: Processed, params: AnimateDiffProcess): video_paths = [] logger.info("Merging images into GIF.") - filename_generator = images.FilenameGenerator(p, p.seed, p.prompt, None) - output_dir = Path(f"{p.outpath_samples}/AnimateDiff/{filename_generator.datetime()}") + date = datetime.datetime.now().strftime('%Y-%m-%d') + output_dir = Path(f"{p.outpath_samples}/AnimateDiff/{date}") output_dir.mkdir(parents=True, exist_ok=True) step = params.video_length if params.video_length > params.batch_size else params.batch_size for i in range(res.index_of_first_image, len(res.images), step): From c86da9cdf2d9ac38ea5a8584d12fc831759cfcb6 Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Sun, 5 Nov 2023 04:34:40 -0600 Subject: [PATCH 34/54] remove unnecessary code for ipadapter --- scripts/animatediff_infv2v.py | 56 +++++++---------------------------- 1 file changed, 10 insertions(+), 46 deletions(-) diff --git a/scripts/animatediff_infv2v.py b/scripts/animatediff_infv2v.py index 8dc5532d..90b77b72 
100644 --- a/scripts/animatediff_infv2v.py +++ b/scripts/animatediff_infv2v.py @@ -100,26 +100,10 @@ def hack(self, params: AnimateDiffProcess): def mm_cn_select(context: List[int]): # take control images for current context. - if cn_script is not None and cn_script.latest_network is not None: + if cn_script and cn_script.latest_network: from scripts.hook import ControlModelType for control in cn_script.latest_network.control_params: - if control.control_model_type == ControlModelType.IPAdapter: - ip_adapter_key = list(control.hint_cond)[0] - if ip_adapter_key == "image_embeds": - if control.hint_cond[ip_adapter_key].shape[0] > len(context): - control.hint_cond_backup = control.hint_cond[ip_adapter_key] - control.hint_cond[ip_adapter_key] = control.hint_cond[ip_adapter_key][context] - if control.hr_hint_cond is not None and control.hr_hint_cond[ip_adapter_key].shape[0] > len(context): - control.hr_hint_cond_backup = control.hr_hint_cond[ip_adapter_key] - control.hr_hint_cond[ip_adapter_key] = control.hr_hint_cond[ip_adapter_key][context] - elif ip_adapter_key == "hidden_states": - if control.hint_cond[ip_adapter_key][-2].shape[0] > len(context): - control.hint_cond_backup = control.hint_cond[ip_adapter_key][-2] - control.hint_cond[ip_adapter_key][-2] = control.hint_cond[ip_adapter_key][-2][context] - if control.hr_hint_cond is not None and control.hr_hint_cond[ip_adapter_key][-2].shape[0] > len(context): - control.hr_hint_cond_backup = control.hr_hint_cond[ip_adapter_key][-2] - control.hr_hint_cond[ip_adapter_key][-2] = control.hr_hint_cond[ip_adapter_key][-2][context] - else: + if control.control_model_type != ControlModelType.IPAdapter: if control.hint_cond.shape[0] > len(context): control.hint_cond_backup = control.hint_cond control.hint_cond = control.hint_cond[context] @@ -139,36 +123,16 @@ def mm_cn_select(context: List[int]): def mm_cn_restore(context: List[int]): # restore control images for next context - if cn_script is not None and cn_script.latest_network is not None: + if cn_script and cn_script.latest_network: from scripts.hook import ControlModelType for control in cn_script.latest_network.control_params: - if getattr(control, "hint_cond_backup", None) is not None: - if control.control_model_type == ControlModelType.IPAdapter: - ip_adapter_key = list(control.hint_cond_backup)[0] - if ip_adapter_key == "image_embeds": - control.hint_cond_backup[context] = control.hint_cond[ip_adapter_key] - control.hint_cond[ip_adapter_key] = control.hint_cond_backup - elif ip_adapter_key == "hidden_states": - control.hint_cond_backup[context] = control.hint_cond[ip_adapter_key][-2] - control.hint_cond[ip_adapter_key][-2] = control.hint_cond_backup - else: - control.hint_cond_backup[context] = control.hint_cond - control.hint_cond = control.hint_cond_backup - if control.hr_hint_cond is not None and getattr(control, "hr_hint_cond_backup", None) is not None: - if control.control_model_type == ControlModelType.IPAdapter: - ip_adapter_key = list(control.hr_hint_cond_backup)[0] - if ip_adapter_key == "image_embeds": - control.hr_hint_cond_backup[ip_adapter_key][context] = control.hr_hint_cond[ip_adapter_key] - control.hr_hint_cond[ip_adapter_key] = control.hr_hint_cond_backup[ip_adapter_key] - elif ip_adapter_key == "hidden_states": - control.hr_hint_cond_backup[context] = control.hr_hint_cond[ip_adapter_key][-2] - control.hr_hint_cond[ip_adapter_key][-2] = control.hr_hint_cond_backup - else: - control.hr_hint_cond_backup[context] = control.hr_hint_cond - control.hr_hint_cond = 
control.hr_hint_cond_backup - if control.control_model_type == ControlModelType.IPAdapter and getattr(control.control_model, "image_emb_backup", None) is not None: - # control.control_model.image_emb_backup[context] = control.control_model.image_emb - # control.control_model.uncond_image_emb_backup[context] = control.control_model.uncond_image_emb + if getattr(control, "hint_cond_backup", None) and control.control_model_type != ControlModelType.IPAdapter: + control.hint_cond_backup[context] = control.hint_cond + control.hint_cond = control.hint_cond_backup + if control.hr_hint_cond and getattr(control, "hr_hint_cond_backup", None) and control.control_model_type != ControlModelType.IPAdapter: + control.hr_hint_cond_backup[context] = control.hr_hint_cond + control.hr_hint_cond = control.hr_hint_cond_backup + if control.control_model_type == ControlModelType.IPAdapter and getattr(control.control_model, "image_emb_backup", None): control.control_model.image_emb = control.control_model.image_emb_backup control.control_model.uncond_image_emb = control.control_model.uncond_image_emb_backup if control.control_model_type == ControlModelType.Controlllite: From 54ba5dba09c61cb30531c52dc9bff9d6fc38a14f Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Sun, 5 Nov 2023 17:48:58 -0600 Subject: [PATCH 35/54] vram optim --- scripts/animatediff_cn.py | 12 ++++++------ scripts/animatediff_infv2v.py | 33 +++++++++++++++++++-------------- 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/scripts/animatediff_cn.py b/scripts/animatediff_cn.py index 7ad2f49f..432c7a6f 100644 --- a/scripts/animatediff_cn.py +++ b/scripts/animatediff_cn.py @@ -412,21 +412,21 @@ def set_numpy_seed(p: processing.StableDiffusionProcessing) -> Optional[int]: if control_model_type == ControlModelType.IPAdapter: if model_net.is_plus: - controls_ipadapter['hidden_states'].append(control['hidden_states'][-2]) + controls_ipadapter['hidden_states'].append(control['hidden_states'][-2].cpu()) else: - controls_ipadapter['image_embeds'].append(control['image_embeds']) + controls_ipadapter['image_embeds'].append(control['image_embeds'].cpu()) if hr_control is not None: if model_net.is_plus: - hr_controls_ipadapter['hidden_states'].append(hr_control['hidden_states'][-2]) + hr_controls_ipadapter['hidden_states'].append(hr_control['hidden_states'][-2].cpu()) else: - hr_controls_ipadapter['image_embeds'].append(hr_control['image_embeds']) + hr_controls_ipadapter['image_embeds'].append(hr_control['image_embeds'].cpu()) else: hr_controls_ipadapter = None hr_controls = None else: - controls.append(control) + controls.append(control.cpu()) if hr_control is not None: - hr_controls.append(hr_control) + hr_controls.append(hr_control.cpu()) else: hr_controls = None diff --git a/scripts/animatediff_infv2v.py b/scripts/animatediff_infv2v.py index 90b77b72..c11b422e 100644 --- a/scripts/animatediff_infv2v.py +++ b/scripts/animatediff_infv2v.py @@ -103,19 +103,23 @@ def mm_cn_select(context: List[int]): if cn_script and cn_script.latest_network: from scripts.hook import ControlModelType for control in cn_script.latest_network.control_params: - if control.control_model_type != ControlModelType.IPAdapter: + if control.control_model_type not in [ControlModelType.IPAdapter, ControlModelType.Controlllite]: if control.hint_cond.shape[0] > len(context): control.hint_cond_backup = control.hint_cond control.hint_cond = control.hint_cond[context] - if control.hr_hint_cond is not None and control.hr_hint_cond.shape[0] > len(context): - 
control.hr_hint_cond_backup = control.hr_hint_cond - control.hr_hint_cond = control.hr_hint_cond[context] - if control.control_model_type == ControlModelType.IPAdapter and control.control_model.image_emb.shape[0] > len(context): + control.hint_cond = control.hint_cond.to(device=shared.device) + if control.hr_hint_cond: + if control.hr_hint_cond.shape[0] > len(context): + control.hr_hint_cond_backup = control.hr_hint_cond + control.hr_hint_cond = control.hr_hint_cond[context] + control.hr_hint_cond = control.hr_hint_cond.to(device=shared.device) + # IPAdapter and Controlllite are always on CPU. + elif control.control_model_type == ControlModelType.IPAdapter and control.control_model.image_emb.shape[0] > len(context): control.control_model.image_emb_backup = control.control_model.image_emb control.control_model.image_emb = control.control_model.image_emb[context] control.control_model.uncond_image_emb_backup = control.control_model.uncond_image_emb control.control_model.uncond_image_emb = control.control_model.uncond_image_emb[context] - if control.control_model_type == ControlModelType.Controlllite: + elif control.control_model_type == ControlModelType.Controlllite: for module in control.control_model.modules.values(): if module.cond_image.shape[0] > len(context): module.cond_image_backup = module.cond_image @@ -126,16 +130,17 @@ def mm_cn_restore(context: List[int]): if cn_script and cn_script.latest_network: from scripts.hook import ControlModelType for control in cn_script.latest_network.control_params: - if getattr(control, "hint_cond_backup", None) and control.control_model_type != ControlModelType.IPAdapter: - control.hint_cond_backup[context] = control.hint_cond - control.hint_cond = control.hint_cond_backup - if control.hr_hint_cond and getattr(control, "hr_hint_cond_backup", None) and control.control_model_type != ControlModelType.IPAdapter: - control.hr_hint_cond_backup[context] = control.hr_hint_cond - control.hr_hint_cond = control.hr_hint_cond_backup - if control.control_model_type == ControlModelType.IPAdapter and getattr(control.control_model, "image_emb_backup", None): + if control.control_model_type not in [ControlModelType.IPAdapter, ControlModelType.Controlllite]: + if getattr(control, "hint_cond_backup", None): + control.hint_cond_backup[context] = control.hint_cond.to(device="cpu") + control.hint_cond = control.hint_cond_backup + if control.hr_hint_cond and getattr(control, "hr_hint_cond_backup", None): + control.hr_hint_cond_backup[context] = control.hr_hint_cond.to(device="cpu") + control.hr_hint_cond = control.hr_hint_cond_backup + elif control.control_model_type == ControlModelType.IPAdapter and getattr(control.control_model, "image_emb_backup", None): control.control_model.image_emb = control.control_model.image_emb_backup control.control_model.uncond_image_emb = control.control_model.uncond_image_emb_backup - if control.control_model_type == ControlModelType.Controlllite: + elif control.control_model_type == ControlModelType.Controlllite: for module in control.control_model.modules.values(): if module.cond_image.shape[0] > len(context): module.set_cond_image(module.cond_image_backup) From f050507fb2bf4410de472e83675109d9270ebdca Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Sun, 5 Nov 2023 18:50:23 -0600 Subject: [PATCH 36/54] recover from assertion error such as OOM without the need to re-start --- scripts/animatediff_cn.py | 33 ++++++++++++++++++++------------- scripts/animatediff_infv2v.py | 11 ++++++++--- scripts/animatediff_lora.py | 15 ++++++++++----- 
scripts/animatediff_mm.py | 16 +++++++++++++--- 4 files changed, 51 insertions(+), 24 deletions(-) diff --git a/scripts/animatediff_cn.py b/scripts/animatediff_cn.py index 432c7a6f..1edf625e 100644 --- a/scripts/animatediff_cn.py +++ b/scripts/animatediff_cn.py @@ -22,12 +22,11 @@ class AnimateDiffControl: + original_processing_process_images_hijack = None + original_controlnet_main_entry = None + original_postprocess_batch = None def __init__(self, p: StableDiffusionProcessing, prompt_scheduler: AnimateDiffPromptSchedule): - self.original_processing_process_images_hijack = None - self.original_img2img_process_batch_hijack = None - self.original_controlnet_main_entry = None - self.original_postprocess_batch = None try: from scripts.external_code import find_cn_script self.cn_script = find_cn_script(p.scripts) @@ -118,15 +117,19 @@ def hacked_processing_process_images_hijack(self, p: StableDiffusionProcessing, update_infotext(p, params) return getattr(processing, '__controlnet_original_process_images_inner')(p, *args, **kwargs) - self.original_processing_process_images_hijack = BatchHijack.processing_process_images_hijack + if AnimateDiffControl.original_processing_process_images_hijack is not None: + logger.info('BatchHijack already hacked.') + return + + AnimateDiffControl.original_processing_process_images_hijack = BatchHijack.processing_process_images_hijack BatchHijack.processing_process_images_hijack = hacked_processing_process_images_hijack processing.process_images_inner = instance.processing_process_images_hijack def restore_batchhijack(self): from scripts.batch_hijack import BatchHijack, instance - BatchHijack.processing_process_images_hijack = self.original_processing_process_images_hijack - self.original_processing_process_images_hijack = None + BatchHijack.processing_process_images_hijack = AnimateDiffControl.original_processing_process_images_hijack + AnimateDiffControl.original_processing_process_images_hijack = None processing.process_images_inner = instance.processing_process_images_hijack @@ -599,17 +602,21 @@ def hacked_postprocess_batch(self, p, *args, **kwargs): images[i] = post_processor(images[i], i) return - self.original_controlnet_main_entry = self.cn_script.controlnet_main_entry - self.original_postprocess_batch = self.cn_script.postprocess_batch + if AnimateDiffControl.original_controlnet_main_entry is not None: + logger.info('ControlNet Main Entry already hacked.') + return + + AnimateDiffControl.original_controlnet_main_entry = self.cn_script.controlnet_main_entry + AnimateDiffControl.original_postprocess_batch = self.cn_script.postprocess_batch self.cn_script.controlnet_main_entry = MethodType(hacked_main_entry, self.cn_script) self.cn_script.postprocess_batch = MethodType(hacked_postprocess_batch, self.cn_script) def restore_cn(self): - self.cn_script.controlnet_main_entry = self.original_controlnet_main_entry - self.original_controlnet_main_entry = None - self.cn_script.postprocess_batch = self.original_postprocess_batch - self.original_postprocess_batch = None + self.cn_script.controlnet_main_entry = AnimateDiffControl.original_controlnet_main_entry + AnimateDiffControl.original_controlnet_main_entry = None + self.cn_script.postprocess_batch = AnimateDiffControl.original_postprocess_batch + AnimateDiffControl.original_postprocess_batch = None def hack(self, params: AnimateDiffProcess): diff --git a/scripts/animatediff_infv2v.py b/scripts/animatediff_infv2v.py index c11b422e..ea78e818 100644 --- a/scripts/animatediff_infv2v.py +++ 
b/scripts/animatediff_infv2v.py @@ -16,9 +16,9 @@ class AnimateDiffInfV2V: + cfg_original_forward = None def __init__(self, p, prompt_scheduler: AnimateDiffPromptSchedule): - self.cfg_original_forward = None try: from scripts.external_code import find_cn_script self.cn_script = find_cn_script(p.scripts) @@ -93,8 +93,12 @@ def get_unsorted_index(lst): def hack(self, params: AnimateDiffProcess): + if AnimateDiffInfV2V.cfg_original_forward is not None: + logger.info("CFGDenoiser already hacked") + return + logger.info(f"Hacking CFGDenoiser forward function.") - self.cfg_original_forward = CFGDenoiser.forward + AnimateDiffInfV2V.cfg_original_forward = CFGDenoiser.forward cn_script = self.cn_script prompt_scheduler = self.prompt_scheduler @@ -310,4 +314,5 @@ def mm_cfg_forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image def restore(self): logger.info(f"Restoring CFGDenoiser forward function.") - CFGDenoiser.forward = self.cfg_original_forward + CFGDenoiser.forward = AnimateDiffInfV2V.cfg_original_forward + AnimateDiffInfV2V.cfg_original_forward = None diff --git a/scripts/animatediff_lora.py b/scripts/animatediff_lora.py index 2a5d0236..c978645b 100644 --- a/scripts/animatediff_lora.py +++ b/scripts/animatediff_lora.py @@ -10,20 +10,24 @@ sys.path.append(f"{extensions_builtin_dir}/Lora") class AnimateDiffLora: + original_load_network = None def __init__(self, v2: bool): - self.original_load_network = None self.v2 = v2 def hack(self): if not self.v2: return - logger.info("Hacking lora to support motion lora") + if AnimateDiffLora.original_load_network is not None: + logger.info("AnimateDiff LoRA already hacked") + return + + logger.info("Hacking loral to support motion lora") import network import networks - self.original_load_network = networks.load_network - original_load_network = self.original_load_network + AnimateDiffLora.original_load_network = networks.load_network + original_load_network = AnimateDiffLora.original_load_network def mm_load_network(name, network_on_disk): @@ -70,4 +74,5 @@ def restore(self): if self.v2: logger.info("Restoring hacked lora") import networks - networks.load_network = self.original_load_network + networks.load_network = AnimateDiffLora.original_load_network + AnimateDiffLora.original_load_network = None diff --git a/scripts/animatediff_mm.py b/scripts/animatediff_mm.py index 08e26158..0c191b75 100644 --- a/scripts/animatediff_mm.py +++ b/scripts/animatediff_mm.py @@ -11,6 +11,7 @@ class AnimateDiffMM: + mm_injected = False def __init__(self): self.mm: MotionWrapper = None @@ -45,13 +46,15 @@ def _load(self, model_name): def inject(self, sd_model, model_name="mm_sd_v15.ckpt"): + if AnimateDiffMM.mm_injected: + logger.info("Motion module already injected. Trying to restore.") + self.restore(sd_model) + unet = sd_model.model.diffusion_model self._load(model_name) inject_sdxl = sd_model.is_sdxl or self.mm.is_sdxl sd_ver = "SDXL" if sd_model.is_sdxl else "SD1.5" - if sd_model.is_sdxl != self.mm.is_sdxl: - logger.warn(f"Motion module incompatible with SD. You are using {sd_ver} with {self.mm.mm_type}. " - f"You will see an error afterwards. Even if the injection and inference seem to go on, you will get bad results.") + assert sd_model.is_sdxl == self.mm.is_sdxl, f"Motion module incompatible with SD. You are using {sd_ver} with {self.mm.mm_type}." 
if self.mm.is_v2: logger.info(f"Injecting motion module {model_name} into {sd_ver} UNet middle block.") @@ -94,6 +97,7 @@ def groupnorm32_mm_forward(self, x): self._set_ddim_alpha(sd_model) self._set_layer_mapping(sd_model) + AnimateDiffMM.mm_injected = True logger.info(f"Injection finished.") @@ -102,11 +106,13 @@ def restore(self, sd_model): sd_ver = "SDXL" if sd_model.is_sdxl else "SD1.5" self._restore_ddim_alpha(sd_model) unet = sd_model.model.diffusion_model + logger.info(f"Removing motion module from {sd_ver} UNet input blocks.") for unet_idx in [1, 2, 4, 5, 7, 8, 10, 11]: if inject_sdxl and unet_idx >= 9: break unet.input_blocks[unet_idx].pop(-1) + logger.info(f"Removing motion module from {sd_ver} UNet output blocks.") for unet_idx in range(12): if inject_sdxl and unet_idx >= 9: @@ -115,6 +121,7 @@ def restore(self, sd_model): unet.output_blocks[unet_idx].pop(-2) else: unet.output_blocks[unet_idx].pop(-1) + if self.mm.is_v2: logger.info(f"Removing motion module from {sd_ver} UNet middle block.") unet.middle_block.pop(-2) @@ -126,6 +133,8 @@ def restore(self, sd_model): from ldm.modules.diffusionmodules.util import GroupNorm32 GroupNorm32.forward = self.gn32_original_forward self.gn32_original_forward = None + + AnimateDiffMM.mm_injected = False logger.info(f"Removal finished.") if shared.cmd_opts.lowvram: self.unload() @@ -155,6 +164,7 @@ def _set_layer_mapping(self, sd_model): sd_model.network_layer_mapping[name] = module module.network_layer_name = name + def _restore_ddim_alpha(self, sd_model): logger.info(f"Restoring DDIM alpha.") sd_model.alphas_cumprod = self.prev_alpha_cumprod From 9f5fbb5f9fdd10de9a95e57fd9d96e4ce9099c80 Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Sun, 5 Nov 2023 19:55:41 -0600 Subject: [PATCH 37/54] bugfix --- scripts/animatediff_infv2v.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/scripts/animatediff_infv2v.py b/scripts/animatediff_infv2v.py index ea78e818..b13f6ba1 100644 --- a/scripts/animatediff_infv2v.py +++ b/scripts/animatediff_infv2v.py @@ -112,7 +112,7 @@ def mm_cn_select(context: List[int]): control.hint_cond_backup = control.hint_cond control.hint_cond = control.hint_cond[context] control.hint_cond = control.hint_cond.to(device=shared.device) - if control.hr_hint_cond: + if control.hr_hint_cond is not None: if control.hr_hint_cond.shape[0] > len(context): control.hr_hint_cond_backup = control.hr_hint_cond control.hr_hint_cond = control.hr_hint_cond[context] @@ -135,13 +135,13 @@ def mm_cn_restore(context: List[int]): from scripts.hook import ControlModelType for control in cn_script.latest_network.control_params: if control.control_model_type not in [ControlModelType.IPAdapter, ControlModelType.Controlllite]: - if getattr(control, "hint_cond_backup", None): + if getattr(control, "hint_cond_backup", None) is not None: control.hint_cond_backup[context] = control.hint_cond.to(device="cpu") control.hint_cond = control.hint_cond_backup - if control.hr_hint_cond and getattr(control, "hr_hint_cond_backup", None): + if control.hr_hint_cond is not None and getattr(control, "hr_hint_cond_backup", None) is not None: control.hr_hint_cond_backup[context] = control.hr_hint_cond.to(device="cpu") control.hr_hint_cond = control.hr_hint_cond_backup - elif control.control_model_type == ControlModelType.IPAdapter and getattr(control.control_model, "image_emb_backup", None): + elif control.control_model_type == ControlModelType.IPAdapter and getattr(control.control_model, "image_emb_backup", None) is not None: 
control.control_model.image_emb = control.control_model.image_emb_backup control.control_model.uncond_image_emb = control.control_model.uncond_image_emb_backup elif control.control_model_type == ControlModelType.Controlllite: From 47077bca2e975713fe22d7d74c68896214295b4c Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Sun, 5 Nov 2023 20:27:47 -0600 Subject: [PATCH 38/54] add todo to pr a1111 --- scripts/animatediff_i2ibatch.py | 1 + scripts/animatediff_output.py | 1 + 2 files changed, 2 insertions(+) diff --git a/scripts/animatediff_i2ibatch.py b/scripts/animatediff_i2ibatch.py index 1788af58..836d83ad 100644 --- a/scripts/animatediff_i2ibatch.py +++ b/scripts/animatediff_i2ibatch.py @@ -24,6 +24,7 @@ class AnimateDiffI2IBatch: def hack(self): + # TODO: PR this hack to A1111 logger.info("Hacking i2i-batch.") original_img2img_process_batch = img2img.process_batch diff --git a/scripts/animatediff_output.py b/scripts/animatediff_output.py index 542dfdf6..78059742 100644 --- a/scripts/animatediff_output.py +++ b/scripts/animatediff_output.py @@ -41,6 +41,7 @@ def output(self, p: StableDiffusionProcessing, res: Processed, params: AnimateDi if len(video_paths) > 0: if p.is_api: if not AnimateDiffOutput.api_encode_pil_to_base64_hooked: + # TODO: remove this hook when WebUI is updated to v1.7.0 AnimateDiffOutput.api_encode_pil_to_base64_hooked = True from modules.api import api api_encode_pil_to_base64 = api.encode_pil_to_base64 From fcda9934e307e3cca24525af3b21c739ad41a78d Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Mon, 6 Nov 2023 18:39:11 -0600 Subject: [PATCH 39/54] fix lllite, fix absolute path, fix infotext --- scripts/animatediff.py | 7 +++---- scripts/animatediff_i2ibatch.py | 1 - scripts/animatediff_infv2v.py | 2 +- scripts/animatediff_mm.py | 12 ++++++++---- scripts/animatediff_output.py | 2 +- scripts/animatediff_prompt.py | 4 ++++ 6 files changed, 17 insertions(+), 11 deletions(-) diff --git a/scripts/animatediff.py b/scripts/animatediff.py index a5efe0b5..76e6d434 100644 --- a/scripts/animatediff.py +++ b/scripts/animatediff.py @@ -39,8 +39,7 @@ def show(self, is_img2img): def ui(self, is_img2img): - model_dir = shared.opts.data.get("animatediff_model_path", os.path.join(script_dir, "model")) - return (AnimateDiffUiGroup().render(is_img2img, model_dir),) + return (AnimateDiffUiGroup().render(is_img2img, motion_module.get_model_dir()),) def before_process(self, p: StableDiffusionProcessing, params: AnimateDiffProcess): @@ -94,7 +93,7 @@ def on_ui_settings(): shared.opts.add_option( "animatediff_model_path", shared.OptionInfo( - os.path.join(script_dir, "model"), + None, "Path to save AnimateDiff motion modules", gr.Textbox, section=section, @@ -144,7 +143,7 @@ def on_ui_settings(): "animatediff_save_to_custom", shared.OptionInfo( False, - "Save frames to stable-diffusion-webui/outputs/{ txt|img }2img-images/AnimateDiff/{gif filename}/ " + "Save frames to stable-diffusion-webui/outputs/{ txt|img }2img-images/AnimateDiff/{gif filename}/{date} " "instead of stable-diffusion-webui/outputs/{ txt|img }2img-images/{date}/.", gr.Checkbox, section=section diff --git a/scripts/animatediff_i2ibatch.py b/scripts/animatediff_i2ibatch.py index 836d83ad..92ffd0cc 100644 --- a/scripts/animatediff_i2ibatch.py +++ b/scripts/animatediff_i2ibatch.py @@ -300,5 +300,4 @@ def cap_init_image(self, p: StableDiffusionProcessingImg2Img, params): params.batch_size = len(p.init_images) - animatediff_i2ibatch = AnimateDiffI2IBatch() diff --git a/scripts/animatediff_infv2v.py 
b/scripts/animatediff_infv2v.py index b13f6ba1..aa547d39 100644 --- a/scripts/animatediff_infv2v.py +++ b/scripts/animatediff_infv2v.py @@ -146,7 +146,7 @@ def mm_cn_restore(context: List[int]): control.control_model.uncond_image_emb = control.control_model.uncond_image_emb_backup elif control.control_model_type == ControlModelType.Controlllite: for module in control.control_model.modules.values(): - if module.cond_image.shape[0] > len(context): + if getattr(module, "cond_image_backup", None) is not None: module.set_cond_image(module.cond_image_backup) def mm_sd_forward(self, x_in, sigma_in, cond_in, image_cond_in, make_condition_dict): diff --git a/scripts/animatediff_mm.py b/scripts/animatediff_mm.py index 0c191b75..91720601 100644 --- a/scripts/animatediff_mm.py +++ b/scripts/animatediff_mm.py @@ -24,11 +24,15 @@ def set_script_dir(self, script_dir): self.script_dir = script_dir + def get_model_dir(self): + model_dir = shared.opts.data.get("animatediff_model_path", os.path.join(self.script_dir, "model")) + if not model_dir: + model_dir = os.path.join(self.script_dir, "model") + return model_dir + + def _load(self, model_name): - model_path = os.path.join( - shared.opts.data.get("animatediff_model_path", os.path.join(self.script_dir, "model")), - model_name, - ) + model_path = os.path.join(self.get_model_dir(), model_name) if not os.path.isfile(model_path): raise RuntimeError("Please download models manually.") if self.mm is None or self.mm.mm_name != model_name: diff --git a/scripts/animatediff_output.py b/scripts/animatediff_output.py index 78059742..dc26ceef 100644 --- a/scripts/animatediff_output.py +++ b/scripts/animatediff_output.py @@ -152,7 +152,7 @@ def _save( for i, frame in enumerate(frame_list): png_filename = video_path_prefix/f"{i:05}.png" png_info = PngImagePlugin.PngInfo() - png_info.add_text('parameters', res.infotexts[0]) + png_info.add_text('parameters', infotext) imageio.imwrite(png_filename, frame, pnginfo=png_info) if "GIF" in params.format: diff --git a/scripts/animatediff_prompt.py b/scripts/animatediff_prompt.py index 35da88a4..bba96255 100644 --- a/scripts/animatediff_prompt.py +++ b/scripts/animatediff_prompt.py @@ -24,6 +24,10 @@ def save_infotext_txt(self, res: Processed): parts = res.info.split('\nNegative prompt: ', 1) if len(parts) > 1: res.info = f"{self.original_prompt}\nNegative prompt: {parts[1]}" + for i in range(len(res.infotexts)): + parts = res.infotexts[i].split('\nNegative prompt: ', 1) + if len(parts) > 1: + res.infotexts[i] = f"{self.original_prompt}\nNegative prompt: {parts[1]}" write_params_txt(res.info) From eea914578d81a685d5020f317c51561e1a6d4448 Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Mon, 6 Nov 2023 20:25:04 -0600 Subject: [PATCH 40/54] fix ap[i --- scripts/animatediff.py | 14 +++++++------- scripts/animatediff_output.py | 1 + 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/scripts/animatediff.py b/scripts/animatediff.py index 76e6d434..b1e751a1 100644 --- a/scripts/animatediff.py +++ b/scripts/animatediff.py @@ -1,5 +1,3 @@ -import os - import gradio as gr from modules import script_callbacks, scripts, shared from modules.processing import (Processed, StableDiffusionProcessing, @@ -43,7 +41,9 @@ def ui(self, is_img2img): def before_process(self, p: StableDiffusionProcessing, params: AnimateDiffProcess): - if isinstance(params, dict): params = AnimateDiffProcess(**params) + if p.is_api and isinstance(params, dict): + self.ad_params = AnimateDiffProcess(**params) + params = self.ad_params if params.enable: 
logger.info("AnimateDiff process start.") params.set_p(p) @@ -59,25 +59,25 @@ def before_process(self, p: StableDiffusionProcessing, params: AnimateDiffProces def before_process_batch(self, p: StableDiffusionProcessing, params: AnimateDiffProcess, **kwargs): - if isinstance(params, dict): params = AnimateDiffProcess(**params) + if p.is_api and isinstance(params, dict): params = self.ad_params if params.enable and isinstance(p, StableDiffusionProcessingImg2Img) and not hasattr(p, '_animatediff_i2i_batch'): AnimateDiffI2VLatent().randomize(p, params) def postprocess_batch_list(self, p: StableDiffusionProcessing, pp: PostprocessBatchListArgs, params: AnimateDiffProcess, **kwargs): - if isinstance(params, dict): params = AnimateDiffProcess(**params) + if p.is_api and isinstance(params, dict): params = self.ad_params if params.enable: self.prompt_scheduler.save_infotext_img(p) def postprocess_image(self, p: StableDiffusionProcessing, pp: PostprocessImageArgs, params: AnimateDiffProcess, *args): - if isinstance(params, dict): params = AnimateDiffProcess(**params) + if p.is_api and isinstance(params, dict): params = self.ad_params if params.enable and isinstance(p, StableDiffusionProcessingImg2Img) and hasattr(p, '_animatediff_paste_to_full'): p.paste_to = p._animatediff_paste_to_full[p.batch_index] def postprocess(self, p: StableDiffusionProcessing, res: Processed, params: AnimateDiffProcess): - if isinstance(params, dict): params = AnimateDiffProcess(**params) + if p.is_api and isinstance(params, dict): params = self.ad_params if params.enable: self.prompt_scheduler.save_infotext_txt(res) self.cn_hacker.restore() diff --git a/scripts/animatediff_output.py b/scripts/animatediff_output.py index dc26ceef..e3649684 100644 --- a/scripts/animatediff_output.py +++ b/scripts/animatediff_output.py @@ -42,6 +42,7 @@ def output(self, p: StableDiffusionProcessing, res: Processed, params: AnimateDi if p.is_api: if not AnimateDiffOutput.api_encode_pil_to_base64_hooked: # TODO: remove this hook when WebUI is updated to v1.7.0 + logger.info("Hooking api.encode_pil_to_base64 to encode video to base64") AnimateDiffOutput.api_encode_pil_to_base64_hooked = True from modules.api import api api_encode_pil_to_base64 = api.encode_pil_to_base64 From 589ac877fbfbd9b0a5bc790621f6857c60dae2b2 Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Mon, 6 Nov 2023 23:46:06 -0600 Subject: [PATCH 41/54] readme --- README.md | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index be74b2d4..54fac07a 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,7 @@ You might also be interested in another extension I created: [Segment Anything f - `2023/10/21`: [v1.9.4](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.9.4): Save prompt travel to output images, `Reverse` merged to `Closed loop` (See [WebUI Parameters](#webui-parameters)), remove `TimestepEmbedSequential` hijack, remove `hints.js`, better explanation of several context-related parameters. - `2023/10/25`: [v1.10.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.10.0): Support img2img batch. You need ControlNet installed to make it work properly (you do not need to enable ControlNet). See [ControlNet V2V](#controlnet-v2v) for more information. - `2023/10/29`: [v1.11.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.11.0): Support [HotShot-XL](https://github.com/hotshotco/Hotshot-XL) for SDXL. 
See [HotShot-XL](#hotshot-xl) for more information. -- `2023/1?/??`: [v1.11.1](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.11.1): TODO +- `2023/11/06`: [v1.11.1](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.11.1): optimize VRAM to support any number of control images for ControlNet V2V, patch [encode_pil_to_base64](https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/modules/api/api.py#L104-L133) to support api return a video, save frames to `AnimateDIff/yy-mm-dd/`, recover from assertion error without restart. For future update plan, please query [here](https://github.com/continue-revolution/sd-webui-animatediff/pull/224). @@ -68,16 +68,16 @@ For future update plan, please query [here](https://github.com/continue-revoluti 1. Go to txt2img if you want to try txt2gif and img2img if you want to try img2gif. 1. Choose an SD1.5 checkpoint, write prompts, set configurations such as image width/height. If you want to generate multiple GIFs at once, please [change batch number, instead of batch size](#batch-size). 1. Enable AnimateDiff extension, set up [each parameter](#webui-parameters), then click `Generate`. -1. You should see the output GIF on the output gallery. You can access GIF output at `stable-diffusion-webui/outputs/{txt2img or img2img}-images/AnimateDiff`. You can also access image frames at `stable-diffusion-webui/outputs/{txt2img or img2img}-images/{date}`. You may choose to save frames for each generation into separate directories in `Settings/AnimateDiff`. +1. You should see the output GIF on the output gallery. You can access GIF output at `stable-diffusion-webui/outputs/{txt2img or img2img}-images/AnimateDiff/{yy-mm-dd}`. You can also access image frames at `stable-diffusion-webui/outputs/{txt2img or img2img}-images/{yy-mm-dd}`. You may choose to save frames for each generation into separate directories in `Settings/AnimateDiff`. ### API -Just like how you use ControlNet. Here is a sample. Due to the limitation of WebUI, you will not be able to get a video, but only a list of generated frames. You will have to view GIF in your file system, as mentioned at [WebUI](#webui) item 4. For most up-to-date parameters, please read [here](https://github.com/continue-revolution/sd-webui-animatediff/blob/master/scripts/animatediff_ui.py#L26). +It is quite similar to the way you use ControlNet. API will return a video in base64 format. In `format`, `PDF` means to save frames to your file system without returning all the frames. If you want your API to return all frames, please add `Frame` to `format` list. For most up-to-date parameters, please read [here](https://github.com/continue-revolution/sd-webui-animatediff/blob/master/scripts/animatediff_ui.py#L26). ``` 'alwayson_scripts': { 'AnimateDiff': { 'args': [{ 'model': 'mm_sd_v15_v2.ckpt', # Motion module - 'format': ['GIF'], # Save format, 'GIF' | 'MP4' | 'PNG' | 'WEBP' | 'TXT' + 'format': ['GIF'], # Save format, 'GIF' | 'MP4' | 'PNG' | 'WEBP' | 'TXT' | 'Frame' 'enable': True, # Enable AnimateDiff 'video_length': 16, # Number of frames 'fps': 8, # FPS @@ -106,6 +106,7 @@ Just like how you use ControlNet. Here is a sample. Due to the limitation of Web 1. **Save format** — Format of the output. Choose at least one of "GIF"|"MP4"|"WEBP"|"PNG". Check "TXT" if you want infotext, which will live in the same directory as the output GIF. Infotext is also accessible via `stable-diffusion-webui/params.txt` and outputs in all formats. 1. 
You can optimize GIF with `gifsicle` (`apt install gifsicle` required, read [#91](https://github.com/continue-revolution/sd-webui-animatediff/pull/91) for more information) and/or `palette` (read [#104](https://github.com/continue-revolution/sd-webui-animatediff/pull/104) for more information). Go to `Settings/AnimateDiff` to enable them. 1. You can set quality and lossless for WEBP via `Settings/AnimateDiff`. Read [#233](https://github.com/continue-revolution/sd-webui-animatediff/pull/233) for more information. + 1. If you are using API, by adding "PNG" to `format`, you can save all frames to your file system without returning all the frames. If you want your API to return all frames, please add `Frame` to `format` list. 1. **Number of frames** — Choose whatever number you like. If you enter 0 (default): @@ -239,12 +240,6 @@ Batch number is NOT the same as batch size. In A1111 WebUI, batch number is abov We are currently developing approach to support batch size on WebUI in the near future. -## FAQ -1. Q: Will ADetailer be supported? - - A: I'm not planning to support ADetailer. However, I plan to refactor my [Segment Anything](https://github.com/continue-revolution/sd-webui-segment-anything) to achieve similar effects. - - ## Demo ### Basic Usage From 31f67ba424a8d6465785b618e6101f557acb63a0 Mon Sep 17 00:00:00 2001 From: zhangrc Date: Thu, 9 Nov 2023 00:55:42 +0800 Subject: [PATCH 42/54] add request_id in params and filename (#285) * animatediff_output.py filname add request_id * Update animatediff_ui.py AnimateDiffProcess add request_id param * allow request_id empty ,and remove from webui * move request_id at last in init method * remove request_id from webui --------- Co-authored-by: zhangruicheng --- scripts/animatediff_output.py | 3 +++ scripts/animatediff_ui.py | 5 ++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/scripts/animatediff_output.py b/scripts/animatediff_output.py index e3649684..09896301 100644 --- a/scripts/animatediff_output.py +++ b/scripts/animatediff_output.py @@ -32,6 +32,9 @@ def output(self, p: StableDiffusionProcessing, res: Processed, params: AnimateDi seq = images.get_next_sequence_number(output_dir, "") filename = f"{seq:05}-{res.all_seeds[(i-res.index_of_first_image)]}" + if params.request_id : + filename = filename +f"-{params.request_id}" + video_path_prefix = output_dir / filename frame_list = self._add_reverse(params, frame_list) diff --git a/scripts/animatediff_ui.py b/scripts/animatediff_ui.py index 054b2ed4..bab6514d 100644 --- a/scripts/animatediff_ui.py +++ b/scripts/animatediff_ui.py @@ -44,6 +44,7 @@ def __init__( last_frame=None, latent_power_last=1, latent_scale_last=32, + request_id = '', ): self.model = model self.enable = enable @@ -64,10 +65,11 @@ def __init__( self.last_frame = last_frame self.latent_power_last = latent_power_last self.latent_scale_last = latent_scale_last + self.request_id = request_id def get_list(self, is_img2img: bool): - list_var = list(vars(self).values()) + list_var = list(vars(self).values())[:-1] if is_img2img: animatediff_i2ibatch.hack() else: @@ -88,6 +90,7 @@ def get_dict(self, is_img2img: bool): "overlap": self.overlap, "interp": self.interp, "interp_x": self.interp_x, + "request_id":self.request_id, } if motion_module.mm is not None and motion_module.mm.mm_hash is not None: infotext['mm_hash'] = motion_module.mm.mm_hash[:8] From e01cf4f63dfae878464408e182a7016557580096 Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Wed, 8 Nov 2023 11:11:40 -0600 Subject: [PATCH 43/54] fix device --- 
scripts/animatediff_infv2v.py | 4 ++-- tests/test_simple.py | 0 2 files changed, 2 insertions(+), 2 deletions(-) create mode 100644 tests/test_simple.py diff --git a/scripts/animatediff_infv2v.py b/scripts/animatediff_infv2v.py index aa547d39..5c31792a 100644 --- a/scripts/animatediff_infv2v.py +++ b/scripts/animatediff_infv2v.py @@ -111,12 +111,12 @@ def mm_cn_select(context: List[int]): if control.hint_cond.shape[0] > len(context): control.hint_cond_backup = control.hint_cond control.hint_cond = control.hint_cond[context] - control.hint_cond = control.hint_cond.to(device=shared.device) + control.hint_cond = control.hint_cond.to(device=devices.get_device_for("controlnet")) if control.hr_hint_cond is not None: if control.hr_hint_cond.shape[0] > len(context): control.hr_hint_cond_backup = control.hr_hint_cond control.hr_hint_cond = control.hr_hint_cond[context] - control.hr_hint_cond = control.hr_hint_cond.to(device=shared.device) + control.hr_hint_cond = control.hr_hint_cond.to(device=devices.get_device_for("controlnet")) # IPAdapter and Controlllite are always on CPU. elif control.control_model_type == ControlModelType.IPAdapter and control.control_model.image_emb.shape[0] > len(context): control.control_model.image_emb_backup = control.control_model.image_emb diff --git a/tests/test_simple.py b/tests/test_simple.py new file mode 100644 index 00000000..e69de29b From 11aa1d1cdcfcde6eabdf51bc3db2666a0f6bb025 Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Wed, 8 Nov 2023 11:15:10 -0600 Subject: [PATCH 44/54] request id --- scripts/animatediff_output.py | 5 ++--- scripts/animatediff_ui.py | 3 ++- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/scripts/animatediff_output.py b/scripts/animatediff_output.py index 09896301..5854cbcc 100644 --- a/scripts/animatediff_output.py +++ b/scripts/animatediff_output.py @@ -31,9 +31,8 @@ def output(self, p: StableDiffusionProcessing, res: Processed, params: AnimateDi frame_list = [image.copy() for image in res.images[i : i + params.video_length]] seq = images.get_next_sequence_number(output_dir, "") - filename = f"{seq:05}-{res.all_seeds[(i-res.index_of_first_image)]}" - if params.request_id : - filename = filename +f"-{params.request_id}" + filename_suffix = f"-{params.request_id}" if params.request_id else "" + filename = f"{seq:05}-{res.all_seeds[(i-res.index_of_first_image)]}{filename_suffix}" video_path_prefix = output_dir / filename diff --git a/scripts/animatediff_ui.py b/scripts/animatediff_ui.py index bab6514d..729f928d 100644 --- a/scripts/animatediff_ui.py +++ b/scripts/animatediff_ui.py @@ -90,8 +90,9 @@ def get_dict(self, is_img2img: bool): "overlap": self.overlap, "interp": self.interp, "interp_x": self.interp_x, - "request_id":self.request_id, } + if self.request_id: + infotext['request_id'] = self.request_id if motion_module.mm is not None and motion_module.mm.mm_hash is not None: infotext['mm_hash'] = motion_module.mm.mm_hash[:8] if is_img2img: From 4b47860708faa05cf5585027e552b31ae9bf2b7b Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Wed, 8 Nov 2023 11:48:07 -0600 Subject: [PATCH 45/54] add test --- .github/workflows/tests.yaml | 116 +++++++++++++++++++++++++++++++++++ tests/test_simple.py | 41 +++++++++++++ 2 files changed, 157 insertions(+) create mode 100644 .github/workflows/tests.yaml diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml new file mode 100644 index 00000000..1d346edd --- /dev/null +++ b/.github/workflows/tests.yaml @@ -0,0 +1,116 @@ +name: Run AnimateDiff generation with 
Motion LoRA & Prompt Travel on CPU + +on: + - push + - pull_request + +jobs: + build: + runs-on: ubuntu-latest + steps: + - name: Checkout A1111 + uses: actions/checkout@v3 + with: + repository: 'AUTOMATIC1111/stable-diffusion-webui' + path: 'stable-diffusion-webui' + - name: Checkout ControlNet + uses: actions/checkout@v3 + with: + repository: 'Mikubill/sd-webui-controlnet' + path: 'stable-diffusion-webui/extensions/sd-webui-controlnet' + - name: Checkout AnimateDiff + uses: actions/checkout@v3 + with: + repository: 'continue-revolution/sd-webui-animatediff' + path: 'stable-diffusion-webui/extensions/sd-webui-animatediff' + - name: Set up Python 3.11.4 + uses: actions/setup-python@v4 + with: + python-version: 3.11.4 + cache: pip + cache-dependency-path: | + **/requirements*txt + launch.py + - name: Install test dependencies + run: | + pip install wait-for-it + pip install -r requirements-test.txt + working-directory: stable-diffusion-webui + env: + PIP_DISABLE_PIP_VERSION_CHECK: "1" + PIP_PROGRESS_BAR: "off" + - name: Setup environment + run: python launch.py --skip-torch-cuda-test --exit + working-directory: stable-diffusion-webui + env: + PIP_DISABLE_PIP_VERSION_CHECK: "1" + PIP_PROGRESS_BAR: "off" + TORCH_INDEX_URL: https://download.pytorch.org/whl/cpu + WEBUI_LAUNCH_LIVE_OUTPUT: "1" + PYTHONUNBUFFERED: "1" + - name: Cache AnimateDiff models + uses: actions/cache@v3 + with: + path: stable-diffusion-webui/extensions/sd-webui-animatediff/model/ + key: animatediff-models-v1 + - name: Cache LoRA models + uses: actions/cache@v3 + with: + path: stable-diffusion-webui/models/Lora + key: lora-models-v1 + - name: Download AnimateDiff model for testing + run: | + if [ ! -f "extensions/sd-webui-animatediff/model/mm_sd_v15_v2.ckpt" ]; then + curl -Lo extensions/sd-webui-animatediff/model/mm_sd_v15_v2.ckpt https://huggingface.co/guoyww/animatediff/resolve/main/mm_sd_v15_v2.ckpt?download=true + fi + working-directory: stable-diffusion-webui + - name: Download LoRA model for testing + run: | + if [ ! -f "models/Lora/yoimiya.safetensors" ]; then + curl -Lo models/Lora/yoimiya.safetensors https://civitai.com/api/download/models/48374?type=Model&format=SafeTensor + fi + if [ ! 
-f "models/Lora/v2_lora_TiltDown.ckpt" ]; then + curl -Lo models/Lora/v2_lora_TiltDown.ckpt https://huggingface.co/guoyww/animatediff/resolve/main/v2_lora_TiltDown.ckpt?download=true + fi + - name: Start test server + run: > + python -m coverage run + --data-file=.coverage.server + launch.py + --skip-prepare-environment + --skip-torch-cuda-test + --test-server + --do-not-download-clip + --no-half + --disable-opt-split-attention + --use-cpu all + --api-server-stop + 2>&1 | tee output.txt & + working-directory: stable-diffusion-webui + - name: Run tests + run: | + wait-for-it --service 127.0.0.1:7860 -t 600 + python -m pytest -vv --junitxml=test/results.xml --cov ./extensions/sd-webui-animatediff --cov-report=xml --verify-base-url ./extensions/sd-webui-animatediff/tests + working-directory: stable-diffusion-webui + - name: Kill test server + if: always() + run: curl -vv -XPOST http://127.0.0.1:7860/sdapi/v1/server-stop && sleep 10 + - name: Show coverage + run: | + python -m coverage combine .coverage* + python -m coverage report -i + python -m coverage html -i + working-directory: stable-diffusion-webui + - name: Upload main app output + uses: actions/upload-artifact@v3 + if: always() + with: + name: output + path: output.txt + - name: Upload coverage HTML + uses: actions/upload-artifact@v3 + if: always() + with: + name: htmlcov + path: htmlcov + \ No newline at end of file diff --git a/tests/test_simple.py b/tests/test_simple.py index e69de29b..d6b58d1e 100644 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -0,0 +1,41 @@ + +import pytest +import requests + + +@pytest.fixture() +def url_txt2img(base_url): + return f"{base_url}/sdapi/v1/txt2img" + + +@pytest.fixture() +def simple_txt2img_request(): + return { + "prompt": '1girl, yoimiya (genshin impact), origen, line, comet, wink, Masterpiece, BestQuality. UltraDetailed, , \n0: closed mouth\n8: open mouth,', + "negative_prompt": "(sketch, duplicate, ugly, huge eyes, text, logo, monochrome, worst face, (bad and mutated hands:1.3), (worst quality:2.0), (low quality:2.0), (blurry:2.0), horror, geometry, bad_prompt_v2, (bad hands), (missing fingers), multiple limbs, bad anatomy, (interlocked fingers:1.2), Ugly Fingers, (extra digit and hands and fingers and legs and arms:1.4), crown braid, ((2girl)), (deformed fingers:1.2), (long fingers:1.2),succubus wings,horn,succubus horn,succubus hairstyle, (bad-artist-anime), bad-artist, bad hand, grayscale, skin spots, acnes, skin blemishes", + "batch_size": 1, + "steps": 2, + "cfg_scale": 7, + "alwayson_scripts": { + 'AnimateDiff': { + 'args': [{ + 'enable': True, + 'batch_size': 2, + 'video_length': 4, + }] + } + } + } + + +def test_txt2img_simple_performed(url_txt2img, simple_txt2img_request): + ''' + This test checks the following: + - simple t2v generation + - prompt travel + - infinite context generator + - motion lora + ''' + response = requests.post(url_txt2img, json=simple_txt2img_request) + assert response.status_code == 200 + assert isinstance(response.json()['images'][0], str) From 251f873f741a6443eb0ba2eaa5450f7e436812d4 Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Wed, 8 Nov 2023 12:00:19 -0600 Subject: [PATCH 46/54] fix test --- .github/workflows/tests.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 1d346edd..107f96a0 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -72,6 +72,7 @@ jobs: if [ ! 
-f "models/Lora/v2_lora_TiltDown.ckpt" ]; then curl -Lo models/Lora/v2_lora_TiltDown.ckpt https://huggingface.co/guoyww/animatediff/resolve/main/v2_lora_TiltDown.ckpt?download=true fi + working-directory: stable-diffusion-webui - name: Start test server run: > python -m coverage run From 778cc40805da633e277d48384e64e2278ac1da40 Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Wed, 8 Nov 2023 12:12:18 -0600 Subject: [PATCH 47/54] fix test --- .github/workflows/tests.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 107f96a0..a608fdf1 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -66,6 +66,9 @@ jobs: working-directory: stable-diffusion-webui - name: Download LoRA model for testing run: | + if [ ! -d "models/Lora" ]; then + mkdir models/Lora + fi if [ ! -f "models/Lora/yoimiya.safetensors" ]; then curl -Lo models/Lora/yoimiya.safetensors https://civitai.com/api/download/models/48374?type=Model&format=SafeTensor fi From 69719396251bfac439c2d1ee2e8ac36187cf57a9 Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Wed, 8 Nov 2023 12:29:17 -0600 Subject: [PATCH 48/54] fix test --- tests/test_simple.py | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/tests/test_simple.py b/tests/test_simple.py index d6b58d1e..1a9b67bf 100644 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -9,7 +9,7 @@ def url_txt2img(base_url): @pytest.fixture() -def simple_txt2img_request(): +def error_txt2img_request(): return { "prompt": '1girl, yoimiya (genshin impact), origen, line, comet, wink, Masterpiece, BestQuality. UltraDetailed, , \n0: closed mouth\n8: open mouth,', "negative_prompt": "(sketch, duplicate, ugly, huge eyes, text, logo, monochrome, worst face, (bad and mutated hands:1.3), (worst quality:2.0), (low quality:2.0), (blurry:2.0), horror, geometry, bad_prompt_v2, (bad hands), (missing fingers), multiple limbs, bad anatomy, (interlocked fingers:1.2), Ugly Fingers, (extra digit and hands and fingers and legs and arms:1.4), crown braid, ((2girl)), (deformed fingers:1.2), (long fingers:1.2),succubus wings,horn,succubus horn,succubus hairstyle, (bad-artist-anime), bad-artist, bad hand, grayscale, skin spots, acnes, skin blemishes", @@ -28,14 +28,36 @@ def simple_txt2img_request(): } -def test_txt2img_simple_performed(url_txt2img, simple_txt2img_request): +@pytest.fixture() +def correct_txt2img_request(): + return { + "prompt": '1girl, yoimiya (genshin impact), origen, line, comet, wink, Masterpiece, BestQuality. 
UltraDetailed, , \n0: closed mouth\n2: open mouth,', + "negative_prompt": "(sketch, duplicate, ugly, huge eyes, text, logo, monochrome, worst face, (bad and mutated hands:1.3), (worst quality:2.0), (low quality:2.0), (blurry:2.0), horror, geometry, bad_prompt_v2, (bad hands), (missing fingers), multiple limbs, bad anatomy, (interlocked fingers:1.2), Ugly Fingers, (extra digit and hands and fingers and legs and arms:1.4), crown braid, ((2girl)), (deformed fingers:1.2), (long fingers:1.2),succubus wings,horn,succubus horn,succubus hairstyle, (bad-artist-anime), bad-artist, bad hand, grayscale, skin spots, acnes, skin blemishes", + "batch_size": 1, + "steps": 2, + "cfg_scale": 7, + "alwayson_scripts": { + 'AnimateDiff': { + 'args': [{ + 'enable': True, + 'batch_size': 2, + 'video_length': 4, + }] + } + } + } + + +def test_txt2img_simple_performed(url_txt2img, error_txt2img_request, correct_txt2img_request): ''' This test checks the following: - simple t2v generation - prompt travel - infinite context generator - motion lora + - error recovery ''' - response = requests.post(url_txt2img, json=simple_txt2img_request) + assert requests.post(url_txt2img, json=error_txt2img_request).status_code == 500 + response = requests.post(url_txt2img, json=correct_txt2img_request) assert response.status_code == 200 assert isinstance(response.json()['images'][0], str) From c673323eece768bf654ff8bd1b6d75562c03d50a Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Wed, 8 Nov 2023 12:40:05 -0600 Subject: [PATCH 49/54] cheaper test --- tests/test_simple.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_simple.py b/tests/test_simple.py index 1a9b67bf..94ce1eef 100644 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -14,14 +14,14 @@ def error_txt2img_request(): "prompt": '1girl, yoimiya (genshin impact), origen, line, comet, wink, Masterpiece, BestQuality. UltraDetailed, , \n0: closed mouth\n8: open mouth,', "negative_prompt": "(sketch, duplicate, ugly, huge eyes, text, logo, monochrome, worst face, (bad and mutated hands:1.3), (worst quality:2.0), (low quality:2.0), (blurry:2.0), horror, geometry, bad_prompt_v2, (bad hands), (missing fingers), multiple limbs, bad anatomy, (interlocked fingers:1.2), Ugly Fingers, (extra digit and hands and fingers and legs and arms:1.4), crown braid, ((2girl)), (deformed fingers:1.2), (long fingers:1.2),succubus wings,horn,succubus horn,succubus hairstyle, (bad-artist-anime), bad-artist, bad hand, grayscale, skin spots, acnes, skin blemishes", "batch_size": 1, - "steps": 2, + "steps": 1, "cfg_scale": 7, "alwayson_scripts": { 'AnimateDiff': { 'args': [{ 'enable': True, - 'batch_size': 2, - 'video_length': 4, + 'batch_size': 1, + 'video_length': 2, }] } } @@ -31,17 +31,17 @@ def error_txt2img_request(): @pytest.fixture() def correct_txt2img_request(): return { - "prompt": '1girl, yoimiya (genshin impact), origen, line, comet, wink, Masterpiece, BestQuality. UltraDetailed, , \n0: closed mouth\n2: open mouth,', + "prompt": '1girl, yoimiya (genshin impact), origen, line, comet, wink, Masterpiece, BestQuality. 
UltraDetailed, , \n0: closed mouth\n1: open mouth,', "negative_prompt": "(sketch, duplicate, ugly, huge eyes, text, logo, monochrome, worst face, (bad and mutated hands:1.3), (worst quality:2.0), (low quality:2.0), (blurry:2.0), horror, geometry, bad_prompt_v2, (bad hands), (missing fingers), multiple limbs, bad anatomy, (interlocked fingers:1.2), Ugly Fingers, (extra digit and hands and fingers and legs and arms:1.4), crown braid, ((2girl)), (deformed fingers:1.2), (long fingers:1.2),succubus wings,horn,succubus horn,succubus hairstyle, (bad-artist-anime), bad-artist, bad hand, grayscale, skin spots, acnes, skin blemishes", "batch_size": 1, - "steps": 2, + "steps": 1, "cfg_scale": 7, "alwayson_scripts": { 'AnimateDiff': { 'args': [{ 'enable': True, - 'batch_size': 2, - 'video_length': 4, + 'batch_size': 1, + 'video_length': 2, }] } } From 8ec78cfa94810680141648bb0dd9798af1ea288a Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Wed, 8 Nov 2023 12:50:07 -0600 Subject: [PATCH 50/54] cheaper test --- .github/workflows/tests.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index a608fdf1..35066cce 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -85,7 +85,7 @@ jobs: --skip-torch-cuda-test --test-server --do-not-download-clip - --no-half + --no-half-vae --disable-opt-split-attention --use-cpu all --api-server-stop From 7fa9e5e3b55245c995fc304d548d7a1a7b96e3cf Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Wed, 8 Nov 2023 12:58:13 -0600 Subject: [PATCH 51/54] readme --- .github/workflows/tests.yaml | 2 +- README.md | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 35066cce..a608fdf1 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -85,7 +85,7 @@ jobs: --skip-torch-cuda-test --test-server --do-not-download-clip - --no-half-vae + --no-half --disable-opt-split-attention --use-cpu all --api-server-stop diff --git a/README.md b/README.md index 54fac07a..ed9c4db9 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,7 @@ You might also be interested in another extension I created: [Segment Anything f - `2023/10/21`: [v1.9.4](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.9.4): Save prompt travel to output images, `Reverse` merged to `Closed loop` (See [WebUI Parameters](#webui-parameters)), remove `TimestepEmbedSequential` hijack, remove `hints.js`, better explanation of several context-related parameters. - `2023/10/25`: [v1.10.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.10.0): Support img2img batch. You need ControlNet installed to make it work properly (you do not need to enable ControlNet). See [ControlNet V2V](#controlnet-v2v) for more information. - `2023/10/29`: [v1.11.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.11.0): Support [HotShot-XL](https://github.com/hotshotco/Hotshot-XL) for SDXL. See [HotShot-XL](#hotshot-xl) for more information. -- `2023/11/06`: [v1.11.1](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.11.1): optimize VRAM to support any number of control images for ControlNet V2V, patch [encode_pil_to_base64](https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/modules/api/api.py#L104-L133) to support api return a video, save frames to `AnimateDIff/yy-mm-dd/`, recover from assertion error without restart. 
+- `2023/11/06`: [v1.11.1](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.11.1): optimize VRAM to support any number of control images for ControlNet V2V, patch [encode_pil_to_base64](https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/modules/api/api.py#L104-L133) to support api return a video, save frames to `AnimateDIff/yy-mm-dd/`, recover from assertion error without restart, test case, optional [request id](#api) for API. For future update plan, please query [here](https://github.com/continue-revolution/sd-webui-animatediff/pull/224). @@ -94,7 +94,8 @@ It is quite similar to the way you use ControlNet. API will return a video in ba 'latent_scale': 32, # Latent scale 'last_frame': None, # Optional last frame 'latent_power_last': 1, # Optional latent power for last frame - 'latent_scale_last': 32 # Optional latent scale for last frame + 'latent_scale_last': 32,# Optional latent scale for last frame + 'request_id': '' # Optional request id. If provided, outputs will have request id as filename suffix } ] } From 0f912f4b6f13daa7f4c8119b98e5588b2ab2c7fb Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Wed, 8 Nov 2023 13:00:51 -0600 Subject: [PATCH 52/54] run one test at one time --- .github/workflows/tests.yaml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index a608fdf1..c0f3481d 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -1,8 +1,7 @@ name: Run AnimateDiff generation with Motion LoRA & Prompt Travel on CPU on: - - push - - pull_request + push: {} # Remove the branch restriction to trigger the workflow for any branch jobs: build: From 8ccf591756ab837c555a828727bc6c07f34cd8b6 Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Wed, 8 Nov 2023 13:08:25 -0600 Subject: [PATCH 53/54] readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index ed9c4db9..fba426b0 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,7 @@ You might also be interested in another extension I created: [Segment Anything f - `2023/10/21`: [v1.9.4](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.9.4): Save prompt travel to output images, `Reverse` merged to `Closed loop` (See [WebUI Parameters](#webui-parameters)), remove `TimestepEmbedSequential` hijack, remove `hints.js`, better explanation of several context-related parameters. - `2023/10/25`: [v1.10.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.10.0): Support img2img batch. You need ControlNet installed to make it work properly (you do not need to enable ControlNet). See [ControlNet V2V](#controlnet-v2v) for more information. - `2023/10/29`: [v1.11.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.11.0): Support [HotShot-XL](https://github.com/hotshotco/Hotshot-XL) for SDXL. See [HotShot-XL](#hotshot-xl) for more information. -- `2023/11/06`: [v1.11.1](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.11.1): optimize VRAM to support any number of control images for ControlNet V2V, patch [encode_pil_to_base64](https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/modules/api/api.py#L104-L133) to support api return a video, save frames to `AnimateDIff/yy-mm-dd/`, recover from assertion error without restart, test case, optional [request id](#api) for API. 
+- `2023/11/06`: [v1.11.1](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.11.1): optimize VRAM for ControlNet V2V, patch [encode_pil_to_base64](https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/modules/api/api.py#L104-L133) for the API to return a video, save frames to `AnimateDiff/yy-mm-dd/`, recover from assertion error, test case, optional [request id](#api) for API.
 
 For future update plan, please query [here](https://github.com/continue-revolution/sd-webui-animatediff/pull/224).
 
From 665c6153033e99fcdd2fb1787bc480fc0f40f164 Mon Sep 17 00:00:00 2001
From: Chengsong Zhang
Date: Wed, 8 Nov 2023 13:10:51 -0600
Subject: [PATCH 54/54] use ubuntu 20.04

---
 .github/workflows/tests.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index c0f3481d..386c9a1f 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -5,7 +5,7 @@ on:
 
 jobs:
   build:
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-20.04
     steps:
       - name: Checkout A1111
         uses: actions/checkout@v3
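
The API-side changes in this series (base64 video return, the optional `request_id`, and the CPU test added alongside them) can be exercised end to end with a short client script. The snippet below is a minimal sketch and not part of the patch series: the host/port, prompt, and output filename are placeholder assumptions, and it presumes a local WebUI launched with `--api` plus `mm_sd_v15_v2.ckpt` present in the motion module directory.

```
# Minimal sketch of calling the AnimateDiff API path described in this patch series.
# Assumptions (not taken from the patches): a local A1111 WebUI started with --api on
# 127.0.0.1:7860, mm_sd_v15_v2.ckpt downloaded into the motion module directory, and
# the prompt/output filename below, which are placeholders.
import base64
import requests

payload = {
    "prompt": "1girl, walking on the beach",
    "steps": 20,
    "alwayson_scripts": {
        "AnimateDiff": {
            "args": [{
                "enable": True,                # enable AnimateDiff
                "model": "mm_sd_v15_v2.ckpt",  # motion module
                "video_length": 16,            # number of frames
                "fps": 8,
                "format": ["GIF"],             # add "Frame" to also get every frame back
                "request_id": "demo-001",      # optional; used as output filename suffix
            }]
        }
    },
}

response = requests.post("http://127.0.0.1:7860/sdapi/v1/txt2img", json=payload)
response.raise_for_status()

# With the patched encode_pil_to_base64, images[0] is the whole video encoded as base64
# (depending on the WebUI version it may carry a data-URI prefix, which is stripped here).
video_b64 = response.json()["images"][0].split(",", maxsplit=1)[-1]
with open("animatediff-demo.gif", "wb") as f:
    f.write(base64.b64decode(video_b64))
```

If `request_id` is set, the GIF saved under `outputs/{txt2img or img2img}-images/AnimateDiff/{yy-mm-dd}/` carries it as a filename suffix, which makes it easy to match API calls to files on disk.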