Commit

[core] LTX Video (#10021)

* transformer

* make style & make fix-copies

* transformer

* add transformer tests

* 80% vae

* make style

* make fix-copies

* fix

* undo cogvideox changes

* update

* update

* match vae

* add docs

* t2v pipeline working; scheduler needs to be checked

* docs

* add pipeline test

* update

* update

* make fix-copies

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* update

* copy t2v to i2v pipeline

* update

* apply review suggestions

* update

* make style

* remove framewise encoding/decoding

* pack/unpack latents

* image2video

* update

* make fix-copies

* update

* update

* rope scale fix

* debug layerwise code

* remove debug

* Apply suggestions from code review

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* propagate precision changes to i2v pipeline

* remove downcast

* address review comments

* fix comment

* address review comments

* [Single File] LTX support for loading original weights (#10135)

* from original file mixin for ltx

* undo config mapping fn changes

* update

* add single file to pipelines

* update docs

* Update src/diffusers/models/autoencoders/autoencoder_kl_ltx.py

* Update src/diffusers/models/autoencoders/autoencoder_kl_ltx.py

* rename classes based on ltx review

* point to original repository for inference

* make style

* resolve conflicts correctly

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
3 people authored Dec 12, 2024
1 parent 8170dc3 commit 96c376a
Showing 26 changed files with 4,439 additions and 1 deletion.
6 changes: 6 additions & 0 deletions docs/source/en/_toctree.yml
@@ -274,6 +274,8 @@
title: LatteTransformer3DModel
- local: api/models/lumina_nextdit2d
title: LuminaNextDiT2DModel
- local: api/models/ltx_video_transformer3d
title: LTXVideoTransformer3DModel
- local: api/models/mochi_transformer3d
title: MochiTransformer3DModel
- local: api/models/pixart_transformer2d
@@ -312,6 +314,8 @@
title: AutoencoderKLAllegro
- local: api/models/autoencoderkl_cogvideox
title: AutoencoderKLCogVideoX
- local: api/models/autoencoderkl_ltx_video
title: AutoencoderKLLTXVideo
- local: api/models/autoencoderkl_mochi
title: AutoencoderKLMochi
- local: api/models/asymmetricautoencoderkl
@@ -408,6 +412,8 @@
title: Latte
- local: api/pipelines/ledits_pp
title: LEDITS++
- local: api/pipelines/ltx_video
title: LTX
- local: api/pipelines/lumina
title: Lumina-T2X
- local: api/pipelines/marigold
37 changes: 37 additions & 0 deletions docs/source/en/api/models/autoencoderkl_ltx_video.md
@@ -0,0 +1,37 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# AutoencoderKLLTXVideo

The 3D variational autoencoder (VAE) model with KL loss used in [LTX](https://huggingface.co/Lightricks/LTX-Video) was introduced by Lightricks.

The model can be loaded with the following code snippet.

```python
import torch
from diffusers import AutoencoderKLLTXVideo

vae = AutoencoderKLLTXVideo.from_pretrained("TODO/TODO", subfolder="vae", torch_dtype=torch.float32).to("cuda")
```
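
Once loaded, the VAE can encode a video tensor into latents and decode latents back into pixel space. A minimal round-trip sketch, assuming an input of shape `(batch, channels, frames, height, width)` with illustrative dimensions:

```python
import torch

# Dummy video: 1 sample, 3 channels, 9 frames, 512x768 resolution (illustrative values)
video = torch.randn(1, 3, 9, 512, 768, device="cuda", dtype=torch.float32)

with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()
    reconstruction = vae.decode(latents).sample
```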

## AutoencoderKLLTXVideo

[[autodoc]] AutoencoderKLLTXVideo
- decode
- encode
- all

## AutoencoderKLOutput

[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput

## DecoderOutput

[[autodoc]] models.autoencoders.vae.DecoderOutput
30 changes: 30 additions & 0 deletions docs/source/en/api/models/ltx_video_transformer3d.md
@@ -0,0 +1,30 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# LTXVideoTransformer3DModel

A Diffusion Transformer model for 3D video data, used in [LTX](https://huggingface.co/Lightricks/LTX-Video), was introduced by Lightricks.

The model can be loaded with the following code snippet.

```python
import torch
from diffusers import LTXVideoTransformer3DModel

transformer = LTXVideoTransformer3DModel.from_pretrained("TODO/TODO", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
```
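
The transformer is usually not called directly; instead it is passed to a pipeline. A minimal sketch, assuming the `Lightricks/LTX-Video` repository referenced above:

```python
from diffusers import LTXPipeline

# Reuse the transformer loaded above inside the text-to-video pipeline
pipe = LTXPipeline.from_pretrained(
    "Lightricks/LTX-Video", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
```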

## LTXVideoTransformer3DModel

[[autodoc]] LTXVideoTransformer3DModel

## Transformer2DModelOutput

[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
68 changes: 68 additions & 0 deletions docs/source/en/api/pipelines/ltx_video.md
@@ -0,0 +1,68 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->

# LTX

[LTX Video](https://huggingface.co/Lightricks/LTX-Video) is the first DiT-based video generation model capable of generating high-quality videos in real time. It produces 24 FPS videos at a 768x512 resolution faster than they can be watched. Trained on a large-scale dataset of diverse videos, the model generates high-resolution videos with realistic and varied content. Models are provided for both text-to-video and image + text-to-video use cases.

<Tip>

Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.

</Tip>
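
A minimal text-to-video sketch; the prompt and generation parameters below are illustrative, see the pipeline reference further down for the full argument list:

```python
import torch
from diffusers import LTXPipeline
from diffusers.utils import export_to_video

pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16).to("cuda")

prompt = "A serene mountain lake at sunrise, mist drifting slowly over the water"
video = pipe(
    prompt=prompt,
    width=704,
    height=480,
    num_frames=161,
    num_inference_steps=50,
).frames[0]
export_to_video(video, "output.mp4", fps=24)
```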

## Loading Single Files

Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`].

```python
import torch
from diffusers import AutoencoderKLLTXVideo, LTXImageToVideoPipeline, LTXVideoTransformer3DModel

single_file_url = "https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors"
transformer = LTXVideoTransformer3DModel.from_single_file(single_file_url, torch_dtype=torch.bfloat16)
vae = AutoencoderKLLTXVideo.from_single_file(single_file_url, torch_dtype=torch.bfloat16)
pipe = LTXImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video", transformer=transformer, vae=vae, torch_dtype=torch.bfloat16)

# ... inference code ...
```
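
With the components loaded above, a minimal image-to-video sketch; the input image path, prompt, and generation parameters are illustrative:

```python
from diffusers.utils import export_to_video, load_image

pipe.to("cuda")

image = load_image("path/to/conditioning_image.png")  # hypothetical input image
video = pipe(
    image=image,
    prompt="An astronaut slowly turns their head toward the camera",
    width=704,
    height=480,
    num_frames=161,
    num_inference_steps=50,
).frames[0]
export_to_video(video, "output.mp4", fps=24)
```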

Alternatively, the weights can be loaded directly into the pipeline with [`~FromSingleFileMixin.from_single_file`].

```python
import torch
from diffusers import LTXImageToVideoPipeline
from transformers import T5EncoderModel, T5Tokenizer

single_file_url = "https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors"
text_encoder = T5EncoderModel.from_pretrained("Lightricks/LTX-Video", subfolder="text_encoder", torch_dtype=torch.bfloat16)
tokenizer = T5Tokenizer.from_pretrained("Lightricks/LTX-Video", subfolder="tokenizer", torch_dtype=torch.bfloat16)
pipe = LTXImageToVideoPipeline.from_single_file(single_file_url, text_encoder=text_encoder, tokenizer=tokenizer, torch_dtype=torch.bfloat16)
```

## LTXPipeline

[[autodoc]] LTXPipeline
- all
- __call__

## LTXImageToVideoPipeline

[[autodoc]] LTXImageToVideoPipeline
- all
- __call__

## LTXPipelineOutput

[[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput
209 changes: 209 additions & 0 deletions scripts/convert_ltx_to_diffusers.py
@@ -0,0 +1,209 @@
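# Converts original LTX Video checkpoints to the diffusers format.
# Example invocation (paths, dtype, and output directory are illustrative):
#   python scripts/convert_ltx_to_diffusers.py \
#       --transformer_ckpt_path ltx-video-2b-v0.9.safetensors \
#       --vae_ckpt_path ltx-video-2b-v0.9.safetensors \
#       --typecast_text_encoder \
#       --dtype bf16 \
#       --save_pipeline \
#       --output_path ./ltx-video-diffusers
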
import argparse
from typing import Any, Dict

import torch
from safetensors.torch import load_file
from transformers import T5EncoderModel, T5Tokenizer

from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXVideoTransformer3DModel


def remove_keys_(key: str, state_dict: Dict[str, Any]):
state_dict.pop(key)


TOKENIZER_MAX_LENGTH = 128

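# Maps original LTX transformer parameter names to their diffusers counterparts.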
TRANSFORMER_KEYS_RENAME_DICT = {
"patchify_proj": "proj_in",
"adaln_single": "time_embed",
"q_norm": "norm_q",
"k_norm": "norm_k",
}

TRANSFORMER_SPECIAL_KEYS_REMAP = {}

VAE_KEYS_RENAME_DICT = {
# decoder
"up_blocks.0": "mid_block",
"up_blocks.1": "up_blocks.0",
"up_blocks.2": "up_blocks.1.upsamplers.0",
"up_blocks.3": "up_blocks.1",
"up_blocks.4": "up_blocks.2.conv_in",
"up_blocks.5": "up_blocks.2.upsamplers.0",
"up_blocks.6": "up_blocks.2",
"up_blocks.7": "up_blocks.3.conv_in",
"up_blocks.8": "up_blocks.3.upsamplers.0",
"up_blocks.9": "up_blocks.3",
# encoder
"down_blocks.0": "down_blocks.0",
"down_blocks.1": "down_blocks.0.downsamplers.0",
"down_blocks.2": "down_blocks.0.conv_out",
"down_blocks.3": "down_blocks.1",
"down_blocks.4": "down_blocks.1.downsamplers.0",
"down_blocks.5": "down_blocks.1.conv_out",
"down_blocks.6": "down_blocks.2",
"down_blocks.7": "down_blocks.2.downsamplers.0",
"down_blocks.8": "down_blocks.3",
"down_blocks.9": "mid_block",
# common
"conv_shortcut": "conv_shortcut.conv",
"res_blocks": "resnets",
"norm3.norm": "norm3",
"per_channel_statistics.mean-of-means": "latents_mean",
"per_channel_statistics.std-of-means": "latents_std",
}

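# State dict entries matching these substrings are handled by the mapped function (here: dropped) rather than renamed.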
VAE_SPECIAL_KEYS_REMAP = {
"per_channel_statistics.channel": remove_keys_,
"per_channel_statistics.mean-of-means": remove_keys_,
"per_channel_statistics.mean-of-stds": remove_keys_,
}


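# Unwrap common checkpoint nestings ("model" / "module" / "state_dict") to obtain the flat state dict.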
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
state_dict = saved_dict
if "model" in saved_dict.keys():
state_dict = state_dict["model"]
if "module" in saved_dict.keys():
state_dict = state_dict["module"]
if "state_dict" in saved_dict.keys():
state_dict = state_dict["state_dict"]
return state_dict


def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
state_dict[new_key] = state_dict.pop(old_key)


def convert_transformer(
ckpt_path: str,
dtype: torch.dtype,
):
PREFIX_KEY = ""

original_state_dict = get_state_dict(load_file(ckpt_path))
transformer = LTXVideoTransformer3DModel().to(dtype=dtype)

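# First pass: rename parameter keys to their diffusers equivalents.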
for key in list(original_state_dict.keys()):
new_key = key[len(PREFIX_KEY) :]
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)

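# Second pass: apply special-key handlers to any remaining matching keys (none are currently registered for the transformer).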
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)

transformer.load_state_dict(original_state_dict, strict=True)
return transformer


def convert_vae(ckpt_path: str, dtype: torch.dtype):
original_state_dict = get_state_dict(load_file(ckpt_path))
vae = AutoencoderKLLTXVideo().to(dtype=dtype)

for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)

for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)

vae.load_state_dict(original_state_dict, strict=True)
return vae


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
)
parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
parser.add_argument(
"--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
)
parser.add_argument(
"--typecast_text_encoder",
action="store_true",
default=False,
help="Whether or not to apply fp16/bf16 precision to text_encoder",
)
parser.add_argument("--save_pipeline", action="store_true")
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
return parser.parse_args()


DTYPE_MAPPING = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}

VARIANT_MAPPING = {
"fp32": None,
"fp16": "fp16",
"bf16": "bf16",
}


if __name__ == "__main__":
args = get_args()

transformer = None
dtype = DTYPE_MAPPING[args.dtype]
variant = VARIANT_MAPPING[args.dtype]

if args.save_pipeline:
assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None

if args.transformer_ckpt_path is not None:
transformer: LTXVideoTransformer3DModel = convert_transformer(args.transformer_ckpt_path, dtype)
if not args.save_pipeline:
transformer.save_pretrained(
args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant
)

if args.vae_ckpt_path is not None:
vae: AutoencoderKLLTXVideo = convert_vae(args.vae_ckpt_path, dtype)
if not args.save_pipeline:
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant)

if args.save_pipeline:
text_encoder_id = "google/t5-v1_1-xxl"
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)

if args.typecast_text_encoder:
text_encoder = text_encoder.to(dtype=dtype)

# The conversion fails unless the text encoder parameters are made contiguous first.
for param in text_encoder.parameters():
param.data = param.data.contiguous()

scheduler = FlowMatchEulerDiscreteScheduler(
use_dynamic_shifting=True,
base_shift=0.95,
max_shift=2.05,
base_image_seq_len=1024,
max_image_seq_len=4096,
shift_terminal=0.1,
)

pipe = LTXPipeline(
scheduler=scheduler,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
)

pipe.save_pretrained(args.output_path, safe_serialization=True, variant=variant, max_shard_size="5GB")