diff --git a/scripts/animatediff_mm.py b/scripts/animatediff_mm.py index 230d21b1..6f61ea02 100644 --- a/scripts/animatediff_mm.py +++ b/scripts/animatediff_mm.py @@ -4,14 +4,15 @@ import torch from einops import rearrange -from ldm.modules.attention import SpatialTransformer -from ldm.modules.diffusionmodules.openaimodel import (TimestepBlock, - TimestepEmbedSequential) +# from ldm.modules.attention import SpatialTransformer +# from ldm.modules.diffusionmodules.openaimodel import (TimestepBlock, +# TimestepEmbedSequential) from ldm.modules.diffusionmodules.util import GroupNorm32 from modules import hashes, shared, sd_models from modules.devices import cpu, device, torch_gc -from motion_module import MotionWrapper, VanillaTemporalModule +from motion_module import MotionWrapper +# from motion_module import VanillaTemporalModule from scripts.animatediff_logger import logger_animatediff as logger @@ -24,7 +25,7 @@ def __init__(self): self.prev_alpha_cumprod = None self.prev_alpha_cumprod_prev = None self.gn32_original_forward = None - self.tes_original_forward = None + # self.tes_original_forward = None def set_script_dir(self, script_dir): @@ -75,20 +76,20 @@ def inject(self, sd_model, model_name="mm_sd_v15.ckpt"): unet = sd_model.model.diffusion_model self._load(model_name) self.gn32_original_forward = GroupNorm32.forward - self.tes_original_forward = TimestepEmbedSequential.forward gn32_original_forward = self.gn32_original_forward - - def mm_tes_forward(self, x, emb, context=None): - for layer in self: - if isinstance(layer, TimestepBlock): - x = layer(x, emb) - elif isinstance(layer, (SpatialTransformer, VanillaTemporalModule)): - x = layer(x, context) - else: - x = layer(x) - return x - - TimestepEmbedSequential.forward = mm_tes_forward + # self.tes_original_forward = TimestepEmbedSequential.forward + + # def mm_tes_forward(self, x, emb, context=None): + # for layer in self: + # if isinstance(layer, TimestepBlock): + # x = layer(x, emb) + # elif isinstance(layer, (SpatialTransformer, VanillaTemporalModule)): + # x = layer(x, context) + # else: + # x = layer(x) + # return x + + # TimestepEmbedSequential.forward = mm_tes_forward if self.mm.using_v2: logger.info(f"Injecting motion module {model_name} into SD1.5 UNet middle block.") unet.middle_block.insert(-1, self.mm.mid_block.motion_modules[0]) @@ -145,7 +146,7 @@ def restore(self, sd_model): else: logger.info(f"Restoring GroupNorm32 forward function.") GroupNorm32.forward = self.gn32_original_forward - TimestepEmbedSequential.forward = self.tes_original_forward + # TimestepEmbedSequential.forward = self.tes_original_forward logger.info(f"Removal finished.") if shared.cmd_opts.lowvram: self.unload()