Commit
* transformer
* make style & make fix-copies
* transformer
* add transformer tests
* 80% vae
* make style
* make fix-copies
* fix
* undo cogvideox changes
* update
* update
* match vae
* add docs
* t2v pipeline working; scheduler needs to be checked
* docs
* add pipeline test
* update
* update
* make fix-copies
* Apply suggestions from code review (Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>)
* update
* copy t2v to i2v pipeline
* update
* apply review suggestions
* update
* make style
* remove framewise encoding/decoding
* pack/unpack latents
* image2video
* update
* make fix-copies
* update
* update
* rope scale fix
* debug layerwise code
* remove debug
* Apply suggestions from code review (Co-authored-by: YiYi Xu <yixu310@gmail.com>)
* propagate precision changes to i2v pipeline
* remove downcast
* address review comments
* fix comment
* address review comments
* [Single File] LTX support for loading original weights (#10135)
* from original file mixin for ltx
* undo config mapping fn changes
* update
* add single file to pipelines
* update docs
* Update src/diffusers/models/autoencoders/autoencoder_kl_ltx.py
* Update src/diffusers/models/autoencoders/autoencoder_kl_ltx.py
* rename classes based on ltx review
* point to original repository for inference
* make style
* resolve conflicts correctly

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
1 parent 8170dc3 · commit 96c376a · Showing 26 changed files with 4,439 additions and 1 deletion.
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# AutoencoderKLLTXVideo
The 3D variational autoencoder (VAE) model with KL loss used in [LTX](https://huggingface.co/Lightricks/LTX-Video) was introduced by Lightricks.

The model can be loaded with the following code snippet.
```python
import torch
from diffusers import AutoencoderKLLTXVideo

vae = AutoencoderKLLTXVideo.from_pretrained("Lightricks/LTX-Video", subfolder="vae", torch_dtype=torch.float32).to("cuda")
```
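
To see what `encode` and `decode` produce, here is a minimal sketch that round-trips a dummy clip through the VAE; the tensor shape, frame count, and resolution are illustrative assumptions rather than values documented on this page.

```python
import torch

# Dummy video in the assumed (batch, channels, frames, height, width) layout.
video = torch.randn(1, 3, 9, 512, 512, dtype=torch.float32, device="cuda")

with torch.no_grad():
    # encode() returns an AutoencoderKLOutput whose latent_dist can be sampled.
    latents = vae.encode(video).latent_dist.sample()
    # decode() returns a DecoderOutput carrying the reconstructed video.
    reconstruction = vae.decode(latents).sample

print(latents.shape, reconstruction.shape)
```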

## AutoencoderKLLTXVideo

[[autodoc]] AutoencoderKLLTXVideo
- decode
- encode
- all

## AutoencoderKLOutput

[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput

## DecoderOutput

[[autodoc]] models.autoencoders.vae.DecoderOutput
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# LTXVideoTransformer3DModel
The Diffusion Transformer model for 3D video data used in [LTX](https://huggingface.co/Lightricks/LTX-Video) was introduced by Lightricks.

The model can be loaded with the following code snippet.
```python
import torch
from diffusers import LTXVideoTransformer3DModel

transformer = LTXVideoTransformer3DModel.from_pretrained("Lightricks/LTX-Video", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
```
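
When building a pipeline around a separately loaded transformer, the instance can be handed to the pipeline loader directly; a brief sketch, assuming the `Lightricks/LTX-Video` repository provides the remaining components:

```python
import torch
from diffusers import LTXPipeline

# Reuse the transformer loaded above instead of the repository's copy.
pipe = LTXPipeline.from_pretrained(
    "Lightricks/LTX-Video", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
```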

## LTXVideoTransformer3DModel

[[autodoc]] LTXVideoTransformer3DModel

## Transformer2DModelOutput

[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->

# LTX
[LTX Video](https://huggingface.co/Lightricks/LTX-Video) is the first DiT-based video generation model capable of generating high-quality videos in real time. It produces 24 FPS videos at 768x512 resolution faster than they can be watched. Trained on a large-scale dataset of diverse videos, the model generates high-resolution videos with realistic and varied content. We provide a model for both text-to-video and image + text-to-video use cases.

<Tip>

Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.

</Tip>
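
For orientation before the loading options below, here is a hedged end-to-end sketch of text-to-video generation with [`LTXPipeline`]; the prompt, resolution, and frame settings are illustrative choices rather than values taken from this page.

```python
import torch
from diffusers import LTXPipeline
from diffusers.utils import export_to_video

pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16).to("cuda")

prompt = "A clear mountain stream flows over smooth stones in golden evening light"
video = pipe(
    prompt=prompt,
    width=704,
    height=480,
    num_frames=161,  # assumed here: frame counts of the form 8k + 1
    num_inference_steps=50,
).frames[0]
export_to_video(video, "output.mp4", fps=24)
```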

## Loading Single Files

Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`].

```python
import torch
from diffusers import AutoencoderKLLTXVideo, LTXImageToVideoPipeline, LTXVideoTransformer3DModel

single_file_url = "https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors"
transformer = LTXVideoTransformer3DModel.from_single_file(single_file_url, torch_dtype=torch.bfloat16)
vae = AutoencoderKLLTXVideo.from_single_file(single_file_url, torch_dtype=torch.bfloat16)
pipe = LTXImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video", transformer=transformer, vae=vae, torch_dtype=torch.bfloat16)

# ... inference code ...
```

Alternatively, the pipeline can be used to load the weights with [`~FromSingleFileMixin.from_single_file`].

```python
import torch
from diffusers import LTXImageToVideoPipeline
from transformers import T5EncoderModel, T5Tokenizer

single_file_url = "https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors"
text_encoder = T5EncoderModel.from_pretrained("Lightricks/LTX-Video", subfolder="text_encoder", torch_dtype=torch.bfloat16)
tokenizer = T5Tokenizer.from_pretrained("Lightricks/LTX-Video", subfolder="tokenizer", torch_dtype=torch.bfloat16)
pipe = LTXImageToVideoPipeline.from_single_file(single_file_url, text_encoder=text_encoder, tokenizer=tokenizer, torch_dtype=torch.bfloat16)
```
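
As the first example suggests, the single checkpoint file carries both the transformer and VAE weights (the same URL feeds both `from_single_file` calls above), while the T5 text encoder and tokenizer are not part of it and are therefore supplied separately.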

## LTXPipeline

[[autodoc]] LTXPipeline
- all
- __call__

## LTXImageToVideoPipeline

[[autodoc]] LTXImageToVideoPipeline
- all
- __call__

## LTXPipelineOutput

[[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput
import argparse
from typing import Any, Dict

import torch
from safetensors.torch import load_file
from transformers import T5EncoderModel, T5Tokenizer

from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXVideoTransformer3DModel

def remove_keys_(key: str, state_dict: Dict[str, Any]):
    # Handler used to drop keys that have no diffusers counterpart.
    state_dict.pop(key)


TOKENIZER_MAX_LENGTH = 128

TRANSFORMER_KEYS_RENAME_DICT = {
    "patchify_proj": "proj_in",
    "adaln_single": "time_embed",
    "q_norm": "norm_q",
    "k_norm": "norm_k",
}

TRANSFORMER_SPECIAL_KEYS_REMAP = {}

VAE_KEYS_RENAME_DICT = {
    # decoder
    "up_blocks.0": "mid_block",
    "up_blocks.1": "up_blocks.0",
    "up_blocks.2": "up_blocks.1.upsamplers.0",
    "up_blocks.3": "up_blocks.1",
    "up_blocks.4": "up_blocks.2.conv_in",
    "up_blocks.5": "up_blocks.2.upsamplers.0",
    "up_blocks.6": "up_blocks.2",
    "up_blocks.7": "up_blocks.3.conv_in",
    "up_blocks.8": "up_blocks.3.upsamplers.0",
    "up_blocks.9": "up_blocks.3",
    # encoder
    "down_blocks.0": "down_blocks.0",
    "down_blocks.1": "down_blocks.0.downsamplers.0",
    "down_blocks.2": "down_blocks.0.conv_out",
    "down_blocks.3": "down_blocks.1",
    "down_blocks.4": "down_blocks.1.downsamplers.0",
    "down_blocks.5": "down_blocks.1.conv_out",
    "down_blocks.6": "down_blocks.2",
    "down_blocks.7": "down_blocks.2.downsamplers.0",
    "down_blocks.8": "down_blocks.3",
    "down_blocks.9": "mid_block",
    # common
    "conv_shortcut": "conv_shortcut.conv",
    "res_blocks": "resnets",
    "norm3.norm": "norm3",
    "per_channel_statistics.mean-of-means": "latents_mean",
    "per_channel_statistics.std-of-means": "latents_std",
}

VAE_SPECIAL_KEYS_REMAP = {
    "per_channel_statistics.channel": remove_keys_,
    "per_channel_statistics.mean-of-means": remove_keys_,
    "per_channel_statistics.mean-of-stds": remove_keys_,
}
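
# Note (added for clarity, not in the original script): the rename dict above
# maps "mean-of-means" and "std-of-means" onto the diffusers buffers
# latents_mean/latents_std, while the handlers here drop the remaining
# per-channel statistics entries.
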
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    # Unwrap the common nesting conventions ("model" / "module" / "state_dict")
    # found in original checkpoints.
    state_dict = saved_dict
    if "model" in saved_dict.keys():
        state_dict = state_dict["model"]
    if "module" in saved_dict.keys():
        state_dict = state_dict["module"]
    if "state_dict" in saved_dict.keys():
        state_dict = state_dict["state_dict"]
    return state_dict

def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    state_dict[new_key] = state_dict.pop(old_key)

def convert_transformer(
    ckpt_path: str,
    dtype: torch.dtype,
):
    PREFIX_KEY = ""

    original_state_dict = get_state_dict(load_file(ckpt_path))
    transformer = LTXVideoTransformer3DModel().to(dtype=dtype)

    # First pass: substring renames from the original naming scheme to diffusers.
    for key in list(original_state_dict.keys()):
        new_key = key[len(PREFIX_KEY) :]
        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_inplace(original_state_dict, key, new_key)

    # Second pass: key-specific handlers (currently none for the transformer).
    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    transformer.load_state_dict(original_state_dict, strict=True)
    return transformer

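# Illustrative note (added for this write-up, not in the original script): with
# the renames above, a hypothetical original key such as
#     "transformer_blocks.0.attn1.q_norm.weight"
# would be rewritten to
#     "transformer_blocks.0.attn1.norm_q.weight"
# via the "q_norm" -> "norm_q" entry, before load_state_dict(strict=True)
# verifies that every remapped key matches the diffusers module exactly.
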
def convert_vae(ckpt_path: str, dtype: torch.dtype):
    original_state_dict = get_state_dict(load_file(ckpt_path))
    vae = AutoencoderKLLTXVideo().to(dtype=dtype)

    for key in list(original_state_dict.keys()):
        new_key = key[:]
        for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_inplace(original_state_dict, key, new_key)

    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    vae.load_state_dict(original_state_dict, strict=True)
    return vae

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
    )
    parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
    parser.add_argument(
        "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
    )
    parser.add_argument(
        "--typecast_text_encoder",
        action="store_true",
        default=False,
        help="Whether or not to apply fp16/bf16 precision to text_encoder",
    )
    parser.add_argument("--save_pipeline", action="store_true")
    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
    parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
    return parser.parse_args()

DTYPE_MAPPING = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}

VARIANT_MAPPING = {
    "fp32": None,
    "fp16": "fp16",
    "bf16": "bf16",
}

if __name__ == "__main__":
    args = get_args()

    transformer = None
    dtype = DTYPE_MAPPING[args.dtype]
    variant = VARIANT_MAPPING[args.dtype]

    if args.save_pipeline:
        assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None

    if args.transformer_ckpt_path is not None:
        transformer: LTXVideoTransformer3DModel = convert_transformer(args.transformer_ckpt_path, dtype)
        if not args.save_pipeline:
            transformer.save_pretrained(
                args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant
            )

    if args.vae_ckpt_path is not None:
        vae: AutoencoderKLLTXVideo = convert_vae(args.vae_ckpt_path, dtype)
        if not args.save_pipeline:
            vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant)

    if args.save_pipeline:
        text_encoder_id = "google/t5-v1_1-xxl"
        tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
        text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)

        if args.typecast_text_encoder:
            text_encoder = text_encoder.to(dtype=dtype)

        # Apparently, the conversion does not work anymore without this :shrug:
        for param in text_encoder.parameters():
            param.data = param.data.contiguous()

        scheduler = FlowMatchEulerDiscreteScheduler(
            use_dynamic_shifting=True,
            base_shift=0.95,
            max_shift=2.05,
            base_image_seq_len=1024,
            max_image_seq_len=4096,
            shift_terminal=0.1,
        )

        pipe = LTXPipeline(
            scheduler=scheduler,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
        )

        pipe.save_pretrained(args.output_path, safe_serialization=True, variant=variant, max_shard_size="5GB")
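
# Illustrative invocation (script name and paths are placeholders):
#
#     python convert_ltx_to_diffusers.py \
#         --transformer_ckpt_path /path/to/ltx-video-2b-v0.9.safetensors \
#         --vae_ckpt_path /path/to/ltx-video-2b-v0.9.safetensors \
#         --typecast_text_encoder \
#         --dtype bf16 \
#         --save_pipeline \
#         --output_path ./ltx-video-diffusers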