
remove timestep hack
continue-revolution committed Oct 21, 2023
1 parent 33db66c commit 1716e29
Showing 1 changed file with 20 additions and 19 deletions.
scripts/animatediff_mm.py (39 changes: 20 additions & 19 deletions)
@@ -4,14 +4,15 @@
 
 import torch
 from einops import rearrange
-from ldm.modules.attention import SpatialTransformer
-from ldm.modules.diffusionmodules.openaimodel import (TimestepBlock,
-                                                      TimestepEmbedSequential)
+# from ldm.modules.attention import SpatialTransformer
+# from ldm.modules.diffusionmodules.openaimodel import (TimestepBlock,
+#                                                       TimestepEmbedSequential)
 from ldm.modules.diffusionmodules.util import GroupNorm32
 from modules import hashes, shared, sd_models
 from modules.devices import cpu, device, torch_gc
 
-from motion_module import MotionWrapper, VanillaTemporalModule
+from motion_module import MotionWrapper
+# from motion_module import VanillaTemporalModule
 from scripts.animatediff_logger import logger_animatediff as logger
 
 
@@ -24,7 +25,7 @@ def __init__(self):
         self.prev_alpha_cumprod = None
         self.prev_alpha_cumprod_prev = None
         self.gn32_original_forward = None
-        self.tes_original_forward = None
+        # self.tes_original_forward = None
 
 
     def set_script_dir(self, script_dir):
@@ -75,20 +76,20 @@ def inject(self, sd_model, model_name="mm_sd_v15.ckpt"):
         unet = sd_model.model.diffusion_model
         self._load(model_name)
         self.gn32_original_forward = GroupNorm32.forward
-        self.tes_original_forward = TimestepEmbedSequential.forward
         gn32_original_forward = self.gn32_original_forward
-
-        def mm_tes_forward(self, x, emb, context=None):
-            for layer in self:
-                if isinstance(layer, TimestepBlock):
-                    x = layer(x, emb)
-                elif isinstance(layer, (SpatialTransformer, VanillaTemporalModule)):
-                    x = layer(x, context)
-                else:
-                    x = layer(x)
-            return x
-
-        TimestepEmbedSequential.forward = mm_tes_forward
+        # self.tes_original_forward = TimestepEmbedSequential.forward
+
+        # def mm_tes_forward(self, x, emb, context=None):
+        #     for layer in self:
+        #         if isinstance(layer, TimestepBlock):
+        #             x = layer(x, emb)
+        #         elif isinstance(layer, (SpatialTransformer, VanillaTemporalModule)):
+        #             x = layer(x, context)
+        #         else:
+        #             x = layer(x)
+        #     return x
+
+        # TimestepEmbedSequential.forward = mm_tes_forward
         if self.mm.using_v2:
             logger.info(f"Injecting motion module {model_name} into SD1.5 UNet middle block.")
             unet.middle_block.insert(-1, self.mm.mid_block.motion_modules[0])
@@ -145,7 +146,7 @@ def restore(self, sd_model):
         else:
             logger.info(f"Restoring GroupNorm32 forward function.")
             GroupNorm32.forward = self.gn32_original_forward
-        TimestepEmbedSequential.forward = self.tes_original_forward
+        # TimestepEmbedSequential.forward = self.tes_original_forward
         logger.info(f"Removal finished.")
         if shared.cmd_opts.lowvram:
            self.unload()
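
For context: the mm_tes_forward override removed above was a copy of the stock TimestepEmbedSequential.forward from ldm's openaimodel.py, with VanillaTemporalModule added to the branch that receives context. The upstream definition, paraphrased below for comparison (not part of this commit), shows why the hack existed: any layer that is neither a TimestepBlock nor a SpatialTransformer is called as plain layer(x).

# Upstream ldm definition (paraphrased for comparison; not part of this diff).
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    def forward(self, x, emb, context=None):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                # ResBlocks receive the timestep embedding
                x = layer(x, emb)
            elif isinstance(layer, SpatialTransformer):
                # cross-attention blocks receive the text context
                x = layer(x, context)
            else:
                # everything else is called as plain layer(x)
                x = layer(x)
        return x

With the override gone, injected motion modules fall into the final else branch, so VanillaTemporalModule presumably now handles being called as layer(x) on its own, which is why the import is no longer needed here.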