Add Magvit-v2 (vqvae) #656

Open
wants to merge 12 commits into base: master
139 changes: 139 additions & 0 deletions examples/magvit/README.md
@@ -0,0 +1,139 @@
# MAGVIT-v2: Language Model Beats Diffusion -- Tokenizer is Key to Visual Generation

This folder contains the MindSpore implementation of [MAGVIT-v2](https://arxiv.org/pdf/2310.05737). Since the official implementation is **not** open-sourced, we refer to the following reference implementations:
- [MAGVIT-v1](https://github.com/google-research/magvit)
- [magvit2-pytorch](https://github.com/lucidrains/magvit2-pytorch)
- [vector-quantize-pytorch](https://github.com/lucidrains/vector-quantize-pytorch)

We thank the authors of these repositories for their great work.

## Features

- [x] Lookup-Free Quantization (LFQ) (see the sketch after this list)
- [x] VQVAE-2d Training
- [x] VQVAE-3d Training
- [x] VQGAN Training
- [ ] MAGVIT-v2 Transformers
- [ ] MAGVIT-v2 Training
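
As a quick illustration of LFQ from the list above: each latent channel is quantized independently to ±1 by its sign, and the token index is read off the resulting bit pattern, so the implicit codebook has 2^d entries with no learned embedding table. The snippet below is a minimal NumPy sketch of this idea, not the exact implementation in this PR; `lfq_quantize` is a hypothetical helper name.

```python
import numpy as np

def lfq_quantize(z):
    """Minimal LFQ sketch: per-dimension sign quantization to {-1, +1}.

    z: (..., d) array of latents. The implicit codebook has 2**d entries,
    so no learned embedding table is needed (d = 18 gives 262144 codes,
    matching the codebook size reported in the tables below).
    """
    q = np.where(z > 0, 1.0, -1.0)             # quantize each dimension by its sign
    bits = (q > 0).astype(np.int64)            # map {-1, +1} -> {0, 1}
    powers = 2 ** np.arange(z.shape[-1])       # binary weight of each dimension
    indices = (bits * powers).sum(axis=-1)     # integer token id in [0, 2**d)
    return q, indices

q, idx = lfq_quantize(np.random.randn(4, 18))  # 4 latents with d = 18
print(q.shape, idx.shape)                      # (4, 18) (4,)
```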

## Requirements

### Mindspore + Ascend
- **Env**: `Python 3.8.18` and [`CANN 8.0.RC2.beta1`](https://www.hiascend.com/software/cann)
- **Main Dependencies**: [`MindSpore >= 2.3`](https://www.mindspore.cn/)
- **Other Dependencies**: see `requirements.txt`

#### Installation Steps

1. Install MindSpore >= 2.3 according to the [official tutorial](https://www.mindspore.cn/install).
2. For Ascend devices, install the matching *CANN 8.0.RC2.beta1* from the [community edition](https://www.hiascend.com/developer/download/community/result?module=cann&cann=8.0.RC2.beta1), as well as the corresponding driver and firmware packages from [firmware and driver](https://www.hiascend.com/hardware/firmware-drivers/community), as stated in the [official document](https://www.mindspore.cn/install/#%E5%AE%89%E8%A3%85%E6%98%87%E8%85%BEai%E5%A4%84%E7%90%86%E5%99%A8%E9%85%8D%E5%A5%97%E8%BD%AF%E4%BB%B6%E5%8C%85).
3. Install the packages listed in `requirements.txt` with `pip install -r requirements.txt`.
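
After completing the steps above, you can optionally run a quick sanity check of the installation. This generic check is not part of this PR and only uses MindSpore's built-in `run_check` API:

```python
# Optional post-installation sanity check (not part of this PR).
import mindspore as ms

print(ms.__version__)  # expect 2.3 or newer
ms.run_check()         # MindSpore's built-in installation self-check
```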


## Datasets

Here we present an overview of the datasets used for training. For data download and detailed preprocessing instructions, please refer to [datasets](./tools/datasets.md).

### Image Dataset for pretraining

Following the original paper, we use [ImageNet-1K](https://huggingface.co/datasets/ILSVRC/imagenet-1k) to pretrain VQVAE-2d as the initialization.

| Dataset | Train | Val |
| --- | --- | --- |
| ImageNet-1K | 1281167 | 50000 |


### Video Dataset

In this repository, we use [UCF-101](https://www.crcv.ucf.edu/data/UCF101.php) to train the VQVAE-3d.

We use the Train/Test Splits for *Action Recognition*; the statistics are:

| Dataset | Train | Test |
| --- | --- | --- |
| UCF-101 | 9537 | 3783 |


## Training

### 1. Visual Tokenizer: VQVAE

The training of VQVAE is divided into two stages: VQVAE-2d and VQVAE-3d, where the trained VQVAE-2d serves as the initialization for VQVAE-3d.

#### 1.1 2D Tokenizer

We pretrained a VQVAE-2d model on [ImageNet-1K](https://huggingface.co/datasets/ILSVRC/imagenet-1k); its reconstruction quality is as follows:

| Model | Token Type | #Tokens | Dataset | Image Size | Codebook Size | PSNR | SSIM |
|-------| -----------| --------| ------- | -----------| --------------| -----| -----|
| MAGVIT-v2 | 2D | 16x16 |ImageNet | 128x128 | 262144 | 20.013 | 0.5734 |

You can pretrain your model by following these steps:

1) Prepare datasets

We take ImageNet as an example. You can refer to [datasets-ImageNet](./tools/datasets.md#image-dataset-for-pretraining) to download and process the data.


2) Run the training script as below:

```
# standalone training
bash scripts/run_train_vqvae_2d.sh

# parallel training
bash scripts/run_train_vqvae_2d_parallel.sh
```


3) Inflate 2d to 3d

We provide a script for inflation; you can run the following command:

```
python tools/inflate_vae2d_to_3d.py --src VQVAE_2D_MODEL_PATH --target INFLATED_MODEL_PATH
```
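
For intuition, 2D-to-3D inflation typically copies each 2D convolution kernel into a single temporal slice of the corresponding 3D kernel and zero-fills the rest, so the inflated network initially behaves like the 2D model applied frame by frame. The sketch below illustrates this common scheme with a hypothetical helper; the actual logic lives in `tools/inflate_vae2d_to_3d.py` and may differ, e.g. in which temporal slice is filled for causal convolutions.

```python
import numpy as np

def inflate_conv2d_to_3d(w2d, kt=3, mode="center"):
    """Sketch of a common 2D -> 3D conv weight inflation scheme (hypothetical helper).

    w2d: (out_c, in_c, kh, kw) 2D kernel. Returns a (out_c, in_c, kt, kh, kw) kernel
    where the 2D weights occupy one temporal slice and the rest are zero, so the
    inflated model initially reproduces the 2D model frame by frame.
    """
    out_c, in_c, kh, kw = w2d.shape
    w3d = np.zeros((out_c, in_c, kt, kh, kw), dtype=w2d.dtype)
    t = kt // 2 if mode == "center" else kt - 1  # the last slice suits causal convolutions
    w3d[:, :, t] = w2d
    return w3d
```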

#### 1.2 3D Tokenizer

Modify the `--pretrained` path to point to the pretrained VQVAE-2d checkpoint in [run_train_vqvae.sh](./scripts/run_train_vqvae.sh) / [run_train_vqvae_parallel.sh](./scripts/run_train_vqvae_parallel.sh).

Run the training script as below:

```
# standalone training
bash scripts/run_train_vqvae.sh

# parallel training
bash scripts/run_train_vqvae_parallel.sh
```

The VQVAE-3d model we trained is as follows:

| Model | Token Type | #Tokens | Dataset | Video Size | Codebook Size | PSNR | SSIM |
|-------| -----------| ------- | ------- | -----------| --------------| -----| -----|
| MAGVIT-v2 | 3D | 5x16x16 | UCF-101 | 17x128x128 | 262144 | 21.6529 | 0.7415 |
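
As a sanity check on the token grid, assuming the standard MAGVIT-v2 downsampling rates (4x temporal with a causally kept first frame, 8x spatial), a 17x128x128 clip maps to exactly the 5x16x16 token shape in the table above:

```python
# Hypothetical arithmetic check of the token grid implied by the table above,
# assuming 4x causal temporal and 8x spatial downsampling.
frames, height, width = 17, 128, 128
t_tokens = 1 + (frames - 1) // 4     # causal: first frame kept, remaining frames downsampled 4x
h_tokens, w_tokens = height // 8, width // 8
print(t_tokens, h_tokens, w_tokens)  # 5 16 16
```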


### 2. MAGVIT-v2 generation model

The training script of the MAGVIT-v2 generation model is still under development, so stay tuned!


## Evaluation
We provide two common evaluation metrics in our implementation: PSNR and SSIM.
To run the evaluation, use the command: `bash scripts/run_eval_vqvae.sh`.

Please modify `scripts/run_eval_vqvae.sh` accordingly, as shown below (a standalone sketch of how these metrics are typically computed follows this snippet):

```
# To evaluate 2D Tokenizer
--model_class vqvae-2d \
--data_path IMAGE_DATA_FOLDER \
--ckpt_path MODEL_PATH

# To evaluate 3D Tokenizer
--model_class vqvae-3d \
--data_path VIDEO_DATA_FOLDER \
--ckpt_path MODEL_PATH
```
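
For reference, the sketch below shows how per-frame PSNR and SSIM are commonly computed with scikit-image (already listed in `requirements.txt`). It is a simplified stand-in for sanity checks; the repository's own `calculate_psnr` / `calculate_ssim` in `videogvt.eval` may normalize or aggregate differently.

```python
import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

def frame_metrics(real, fake):
    """Mean PSNR/SSIM over frames. real, fake: (t, h, w, c) uint8 arrays."""
    psnrs, ssims = [], []
    for r, f in zip(real, fake):
        psnrs.append(peak_signal_noise_ratio(r, f, data_range=255))
        ssims.append(structural_similarity(r, f, channel_axis=-1, data_range=255))
    return float(np.mean(psnrs)), float(np.mean(ssims))
```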
13 changes: 13 additions & 0 deletions examples/magvit/requirements.txt
@@ -0,0 +1,13 @@
opencv-python
scikit-image
ftfy
regex
albumentations
pillow==9.1.1
tqdm
mindcv
decord
omegaconf
pyyaml
ml-collections
imageio
244 changes: 244 additions & 0 deletions examples/magvit/scripts/eval_vqvae.py
@@ -0,0 +1,244 @@
"""
Infer and evaluate VQVAE
"""

import argparse
import logging
import os
import sys
import time

import numpy as np
from PIL import Image
from tqdm import tqdm

import mindspore as ms

__dir__ = os.path.dirname(os.path.abspath(__file__))
mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../../"))
sys.path.insert(0, mindone_lib_path)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..")))

from videogvt.config.vqgan3d_ucf101_config import get_config
from videogvt.data.loader import create_dataloader
from videogvt.eval import calculate_psnr, calculate_ssim
from videogvt.models.vqvae import build_model
from videogvt.models.vqvae.lpips import LPIPS

from mindone.utils.config import str2bool
from mindone.utils.logger import set_logger

logger = logging.getLogger(__name__)


def _rearrange_in(x):
    if x.ndim == 5:
        b, c, t, h, w = x.shape
        x = x.permute(0, 2, 1, 3, 4)
        x = ms.ops.reshape(x, (b * t, c, h, w))
    return x


def postprocess(x, trim=True):
    pixels = (x + 1) * 127.5
    pixels = np.clip(pixels, 0, 255).astype(np.uint8)
    if pixels.ndim == 5:
        # b, c, t, h, w -> b t c h w
        return np.transpose(pixels, (0, 2, 1, 3, 4))
    else:
        return pixels


def visualize(recons, x=None, save_fn="tmp_vae_recons"):
    # x: (b h w c), np array
    for i in range(recons.shape[0]):
        if x is not None:
            out = np.concatenate((x[i], recons[i]), axis=-2)
        else:
            out = recons[i]
        Image.fromarray(out).save(f"{save_fn}-{i:02d}.png")


def main(args):
    ms.set_context(mode=args.mode, ascend_config={"precision_mode": "allow_mix_precision_bf16"})
    set_logger(name="", output_dir=args.output_path, rank=0)

    config = get_config()
    dtype = {"fp32": ms.float32, "fp16": ms.float16, "bf16": ms.bfloat16}[args.dtype]

    model = build_model(args.model_class, config, is_training=False, dtype=dtype)
    param_dict = ms.load_checkpoint(args.ckpt_path)
    ms.load_param_into_net(model, param_dict)
    model.set_train(False)
    logger.info(f"Loaded checkpoint from {args.ckpt_path}")

    if args.eval_loss:
        lpips_loss_fn = LPIPS()

    ds_config = dict(
        csv_path=args.csv_path,
        data_folder=args.data_path,
        size=args.size,
        crop_size=args.crop_size,
        sample_stride=args.frame_stride,
        sample_n_frames=args.num_frames,
        return_image=False,
        flip=False,
        random_crop=False,
    )

    ds_name = "video" if args.model_class == "vqvae-3d" else "image"
    dataset = create_dataloader(
        ds_config=ds_config,
        batch_size=args.batch_size,
        ds_name=ds_name,
        num_parallel_workers=args.num_parallel_workers,
        shuffle=False,
        drop_remainder=False,
    )
    num_batches = dataset.get_dataset_size()

    ds_iter = dataset.create_dict_iterator(1)

    logger.info("Inference begins")
    mean_infer_time = 0
    mean_psnr = 0
    mean_ssim = 0
    mean_lpips = 0
    mean_recon = 0

    psnr_list = []
    ssim_list = []
    for step, data in tqdm(enumerate(ds_iter)):
        x = data[ds_name].to(dtype)
        start_time = time.time()

        recons = model._forward(x)

        infer_time = time.time() - start_time
        mean_infer_time += infer_time
        logger.info(f"Infer time: {infer_time}")

        generated_videos = postprocess(recons.float().asnumpy())
        real_videos = postprocess(x.float().asnumpy())

        psnr_scores = list(calculate_psnr(real_videos, generated_videos)["value"].values())
        psnr_list += psnr_scores

        ssim_scores = list(calculate_ssim(real_videos, generated_videos)["value"].values())
        ssim_list += ssim_scores

        if args.eval_loss:
            # cast to float before subtracting to avoid uint8 wrap-around
            recon_loss = np.abs(real_videos.astype(np.float32) - generated_videos.astype(np.float32))
            lpips_loss = lpips_loss_fn(_rearrange_in(x), _rearrange_in(recons)).asnumpy()
            mean_recon += recon_loss.mean()
            mean_lpips += lpips_loss.mean()

    mean_psnr = np.mean(psnr_list)
    mean_ssim = np.mean(ssim_list)

    mean_infer_time /= num_batches
    logger.info(f"Mean infer time: {mean_infer_time}")
    logger.info(f"Done. Results saved in {args.output_path}")

    logger.info(f"mean psnr:{mean_psnr:.4f}")
    logger.info(f"mean ssim:{mean_ssim:.4f}")

    if args.eval_loss:
        mean_recon /= num_batches
        mean_lpips /= num_batches
        logger.info(f"mean recon loss: {mean_recon:.4f}")
        logger.info(f"mean lpips loss: {mean_lpips:.4f}")


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ckpt_path",
        default="outputs/vae_train/ckpt/vae_kl_f8-e10.ckpt",
        type=str,
        help="checkpoint path",
    )
    parser.add_argument(
        "--model_class",
        default="vqvae-3d",
        type=str,
        choices=[
            "vqvae-2d",
            "vqvae-3d",
        ],
        help="model arch type",
    )
    parser.add_argument(
        "--csv_path",
        default=None,
        type=str,
        help="path to csv annotation file. If None, will get images from the folder of `data_path`",
    )
    parser.add_argument("--data_path", default="dataset", type=str, help="data path")
    parser.add_argument(
        "--output_path",
        default="samples/vae_recons",
        type=str,
        help="output directory to save inference results",
    )
    parser.add_argument("--size", default=384, type=int, help="image rescale size")
    parser.add_argument("--crop_size", default=256, type=int, help="image crop size")
    parser.add_argument("--num_frames", default=16, type=int, help="num frames")
    parser.add_argument("--frame_stride", default=1, type=int, help="frame sampling stride")

    parser.add_argument(
        "--mode",
        default=0,
        type=int,
        help="Specify the mode: 0 for graph mode, 1 for pynative mode",
    )
    parser.add_argument(
        "--dtype",
        default="fp32",
        type=str,
        choices=["fp32", "fp16", "bf16"],
        help="mixed precision type, if fp32, all layer precision is float32 (amp_level=O0), \
        if bf16 or fp16, amp_level==O2, part of layers will compute in bf16 or fp16 such as matmul, dense, conv.",
    )
    parser.add_argument("--device_target", type=str, default="Ascend", help="Ascend or GPU")
    parser.add_argument("--batch_size", default=4, type=int, help="batch size")
    parser.add_argument(
        "--num_parallel_workers",
        default=8,
        type=int,
        help="num workers for data loading",
    )
    parser.add_argument(
        "--eval_loss",
        default=False,
        type=str2bool,
        help="whether measure loss including reconstruction, kl, perceptual loss",
    )
    parser.add_argument(
        "--save_images",
        default=True,
        type=str2bool,
        help="whether save reconstructed images",
    )
    parser.add_argument(
        "--encode_only",
        default=False,
        type=str2bool,
        help="only encode to save z or distribution",
    )
    parser.add_argument(
        "--save_z_dist",
        default=False,
        type=str2bool,
        help="If True, save z distribution, mean and logvar. Otherwise, save z after sampling.",
    )

    args = parser.parse_args()

    return args


if __name__ == "__main__":
    args = parse_args()
    main(args)