diff --git a/examples/magvit/README.md b/examples/magvit/README.md new file mode 100644 index 0000000000..f1f9cdc20b --- /dev/null +++ b/examples/magvit/README.md @@ -0,0 +1,139 @@ +# MAGVIT-v2: Language Model Beats Diffusion -- Tokenizer is Key to Visual Generation + +This folder contains the Mindspore implementation of [MAGVIT-v2](https://arxiv.org/pdf/2310.05737). Since the official implementation is **NOT open-sourced**, we refer to the following repository implementations: +- [MAGVIT-v1](https://github.com/google-research/magvit) +- [magvit2-pytorch](https://github.com/lucidrains/magvit2-pytorch) +- [vector-quantize-pytorch](https://github.com/lucidrains/vector-quantize-pytorch) + +Thanks for their great work. + +## Features + +- [x] Lookup-Free-Quantization (LFQ) +- [x] VQVAE-2d Training +- [x] VQVAE-3d Training +- [x] VQGAN Training +- [ ] MAGVIT-v2 Transformers +- [ ] MAGVIT-v2 Training + +## Requirements + +### Mindspore + Ascend +- **Env**: `Python 3.8.18` and [`CANN 8.0.RC2.beta1`](https://www.hiascend.com/software/cann) +- **Main Dependencies**: [`Mindspore>=2.3`](https://www.mindspore.cn/) +- **Other Dependencies**: see in `requirements.txt` + +#### Installation Tutorials: + +1. Install Mindspore >=2.3 according to the [official tutorials](https://www.mindspore.cn/install) +2. Ascend users please install the corresponding *CANN 8.0.RC2.beta1* in [community edition](https://www.hiascend.com/developer/download/community/result?module=cann&cann=8.0.RC2.beta1) as well as the relevant driver and firmware packages in [firmware and driver](https://www.hiascend.com/hardware/firmware-drivers/community), as stated in the [official document](https://www.mindspore.cn/install/#%E5%AE%89%E8%A3%85%E6%98%87%E8%85%BEai%E5%A4%84%E7%90%86%E5%99%A8%E9%85%8D%E5%A5%97%E8%BD%AF%E4%BB%B6%E5%8C%85). +3. Install the pacakges listed in requirements.txt with `pip install -r requirements.txt` + + +## Datasets + +Here we present an overview of the datasets we used in training. For data download and detailed preprocessing tutorial, please refer to [datasets](./tools/datasets.md) + +### Image Dataset for pretraining + +Following the original paper, we use [ImageNet-1K](https://huggingface.co/datasets/ILSVRC/imagenet-1k) to pretrain VQVAE-2d as the initialiation. + +| Dataset | Train | Val | +| --- | --- | --- | +| ImageNet-1K | 1281167 | 50000 | + + +### Video Dataset + +In this repositry, we use [UCF-101](https://www.crcv.ucf.edu/data/UCF101.php) to train the VQVAE-3d. + +We use the Train/Test Splits for *Action Recognition*, the statistics are: + +| Dataset | Train | Test | +| --- | --- | --- | +| UCF-101| 9537 | 3783 | + + +## Training + +### 1. Visual Tokenizer: VQVAE + +The training of VQVAE can be divided into two stages: VQVAE-2d and VQVAE-3d, where VQVAE-2d is the initialization of VQVAE-3d. + +#### 1.1 2D Tokenizer + +We pretrained a VQVAE-2d model using [ImageNet-1K](https://huggingface.co/datasets/ILSVRC/imagenet-1k), and the accuracy is as follows: + +| Model | Token Type | #Tokens | Dataset | Image Size | Codebook Size | PSNR | SSIM | +|-------| -----------| --------| ------- | -----------| --------------| -----| -----| +| MAGVIT-v2 | 2D | 16x16 |ImageNet | 128x128 | 262144 | 20.013 | 0.5734 | + +You can pretrain your model by following these steps: + +1) Prepare datasets + +We take ImageNet as an example. You can refer to [datasets-ImageNet](./tools/datasets.md#image-dataset-for-pretraining) to download and process the data. 
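+
+   Optionally, you can sanity-check the folder layout before launching training. The sketch below is only illustrative (`data_path` is a placeholder); it mirrors the glob patterns used by the image loader in `videogvt/data/image_dataset.py`, which scans `*.jpg/*.png/*.jpeg/*.JPEG` files directly under `--data_path` or one directory level below it:
+
+   ```python
+   import glob
+   import os
+
+   data_path = "./datasets/ImageNet/train/"  # placeholder, point this to your dataset
+   files = []
+   for fmt in ("jpg", "png", "jpeg", "JPEG"):
+       files += glob.glob(os.path.join(data_path, f"*.{fmt}"))
+       files += glob.glob(os.path.join(data_path, f"*/*.{fmt}"))
+   print(f"Found {len(files)} images under {data_path}")
+   ```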
+ + +2) Run the training script as below: + + ``` + # standalone training + bash scripts/run_train_vqvae_2d.sh + + # parallel training + bash scripts/run_train_vqvae_2d_parallel.sh + ``` + + +3) Inflate 2d to 3d + + We provide a script for inflation, you can run the command: + + ``` + python tools/inflate_vae2d_to_3d.py --src VQVAE_2D_MODEL_PATH --target INFALTED_MODEL_PATH + ``` + +#### 1.2 3D Tokenizer + +Modify the path of `--pretrained` VQVAE-2d model in [run_train_vqvae.sh](./scripts/run_train_vqvae.sh) / [run_train_vqvae_parallel.sh](./scripts/run_train_vqvae_parallel.sh) + +Run the training script as below: + + ``` + # standalone training + bash scripts/run_train_vqvae.sh + + # parallel training + bash scripts/run_train_vqvae_parallel.sh + ``` + + The VQVAE-3d model we trained is as follows: + +| Model | Token Type | #Tokens | Dataset | Video Size | Codebook Size | PSNR | SSIM | +|-------| -----------| ------- | ------- | -----------| --------------| -----| -----| +| MAGVIT-v2 | 3D | 5x16x16 | UCF-101 | 17x128x128 | 262144 | 21.6529 | 0.7415 | + + +### 2. MAGVIT-v2 generation model + +The training script of MAGVIT-v2 generation model is still under development, so stay tuned! + + +## Evaluation +We provide two common evaluation metrics in our implementations: PSNR and SSIM. +To run the evaluations, you can use the command: `bash scripts/run_eval_vqvae.sh`. + +Please modify the `scripts/run_eval_vqvae.sh` accordingly as shown below: + +``` +# To evaluate 2D Tokenizer +--model_class vqvae-2d \ +--data_path IMAGE_DATA_FOLDER \ +--ckpt_path MODEL_PATH + +# To evaluate 3D Tokenizer +--model_class vqvae-3d \ +--data_path VIDEO_DATA_FOLDER \ +--ckpt_path MODEL_PATH +``` diff --git a/examples/magvit/requirements.txt b/examples/magvit/requirements.txt new file mode 100644 index 0000000000..339bd3711d --- /dev/null +++ b/examples/magvit/requirements.txt @@ -0,0 +1,13 @@ +opencv-python +scikit-image +ftfy +regex +albumentations +pillow==9.1.1 +tqdm +mindcv +decord +omegaconf +pyyaml +ml-collections +imageio diff --git a/examples/magvit/scripts/eval_vqvae.py b/examples/magvit/scripts/eval_vqvae.py new file mode 100644 index 0000000000..b1512e738d --- /dev/null +++ b/examples/magvit/scripts/eval_vqvae.py @@ -0,0 +1,244 @@ +""" +Infer and evaluate VQVAE +""" + +import argparse +import logging +import os +import sys +import time + +import numpy as np +from PIL import Image +from tqdm import tqdm + +import mindspore as ms + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../../")) +sys.path.insert(0, mindone_lib_path) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, ".."))) + +from videogvt.config.vqgan3d_ucf101_config import get_config +from videogvt.data.loader import create_dataloader +from videogvt.eval import calculate_psnr, calculate_ssim +from videogvt.models.vqvae import build_model +from videogvt.models.vqvae.lpips import LPIPS + +from mindone.utils.config import str2bool +from mindone.utils.logger import set_logger + +logger = logging.getLogger(__name__) + + +def _rearrange_in(x): + if x.ndim == 5: + b, c, t, h, w = x.shape + x = x.permute(0, 2, 1, 3, 4) + x = ms.ops.reshape(x, (b * t, c, h, w)) + return x + + +def postprocess(x, trim=True): + pixels = (x + 1) * 127.5 + pixels = np.clip(pixels, 0, 255).astype(np.uint8) + if pixels.ndim == 5: + # b, c, t, h, w -> b t c h w + return np.transpose(pixels, (0, 2, 1, 3, 4)) + else: + return pixels + + +def visualize(recons, x=None, save_fn="tmp_vae_recons"): + # x: (b 
h w c), np array + for i in range(recons.shape[0]): + if x is not None: + out = np.concatenate((x[i], recons[i]), axis=-2) + else: + out = recons[i] + Image.fromarray(out).save(f"{save_fn}-{i:02d}.png") + + +def main(args): + ms.set_context(mode=args.mode, ascend_config={"precision_mode": "allow_mix_precision_bf16"}) + set_logger(name="", output_dir=args.output_path, rank=0) + + config = get_config() + dtype = {"fp32": ms.float32, "fp16": ms.float16, "bf16": ms.bfloat16}[args.dtype] + + model = build_model(args.model_class, config, is_training=False, dtype=dtype) + param_dict = ms.load_checkpoint(args.ckpt_path) + ms.load_param_into_net(model, param_dict) + model.set_train(False) + logger.info(f"Loaded checkpoint from {args.ckpt_path}") + + if args.eval_loss: + lpips_loss_fn = LPIPS() + + ds_config = dict( + csv_path=args.csv_path, + data_folder=args.data_path, + size=args.size, + crop_size=args.crop_size, + sample_stride=args.frame_stride, + sample_n_frames=args.num_frames, + return_image=False, + flip=False, + random_crop=False, + ) + + ds_name = "video" if args.model_class == "vqvae-3d" else "image" + dataset = create_dataloader( + ds_config=ds_config, + batch_size=args.batch_size, + ds_name=ds_name, + num_parallel_workers=args.num_parallel_workers, + shuffle=False, + drop_remainder=False, + ) + num_batches = dataset.get_dataset_size() + + ds_iter = dataset.create_dict_iterator(1) + + logger.info("Inferene begins") + mean_infer_time = 0 + mean_psnr = 0 + mean_ssim = 0 + mean_lpips = 0 + mean_recon = 0 + + psnr_list = [] + ssim_list = [] + for step, data in tqdm(enumerate(ds_iter)): + x = data[ds_name].to(dtype) + start_time = time.time() + + recons = model._forward(x) + + infer_time = time.time() - start_time + mean_infer_time += infer_time + logger.info(f"Infer time: {infer_time}") + + generated_videos = postprocess(recons.float().asnumpy()) + real_videos = postprocess(x.float().asnumpy()) + + psnr_scores = list(calculate_psnr(real_videos, generated_videos)["value"].values()) + psnr_list += psnr_scores + + ssim_scores = list(calculate_ssim(real_videos, generated_videos)["value"].values()) + ssim_list += ssim_scores + + if args.eval_loss: + recon_loss = np.abs((real_videos - generated_videos)) + lpips_loss = lpips_loss_fn(_rearrange_in(x), _rearrange_in(recons)).asnumpy() + mean_recon += recon_loss.mean() + mean_lpips += lpips_loss.mean() + + mean_psnr = np.mean(psnr_list) + mean_ssim = np.mean(ssim_list) + + mean_infer_time /= num_batches + logger.info(f"Mean infer time: {mean_infer_time}") + logger.info(f"Done. Results saved in {args.output_path}") + + logger.info(f"mean psnr:{mean_psnr:.4f}") + logger.info(f"mean ssim:{mean_ssim:.4f}") + + if args.eval_loss: + mean_recon /= num_batches + mean_lpips /= num_batches + logger.info(f"mean recon loss: {mean_recon:.4f}") + logger.info(f"mean lpips loss: {mean_lpips:.4f}") + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--ckpt_path", + default="outputs/vae_train/ckpt/vae_kl_f8-e10.ckpt", + type=str, + help="checkpoint path", + ) + parser.add_argument( + "--model_class", + default="vqvae-3d", + type=str, + choices=[ + "vqvae-2d", + "vqvae-3d", + ], + help="model arch type", + ) + parser.add_argument( + "--csv_path", + default=None, + type=str, + help="path to csv annotation file. 
If None, will get images from the folder of `data_path`", + ) + parser.add_argument("--data_path", default="dataset", type=str, help="data path") + parser.add_argument( + "--output_path", + default="samples/vae_recons", + type=str, + help="output directory to save inference results", + ) + parser.add_argument("--size", default=384, type=int, help="image rescale size") + parser.add_argument("--crop_size", default=256, type=int, help="image crop size") + parser.add_argument("--num_frames", default=16, type=int, help="num frames") + parser.add_argument("--frame_stride", default=1, type=int, help="frame sampling stride") + + parser.add_argument( + "--mode", + default=0, + type=int, + help="Specify the mode: 0 for graph mode, 1 for pynative mode", + ) + parser.add_argument( + "--dtype", + default="fp32", + type=str, + choices=["fp32", "fp16", "bf16"], + help="mixed precision type, if fp32, all layer precision is float32 (amp_level=O0), \ + if bf16 or fp16, amp_level==O2, part of layers will compute in bf16 or fp16 such as matmul, dense, conv.", + ) + parser.add_argument("--device_target", type=str, default="Ascend", help="Ascend or GPU") + parser.add_argument("--batch_size", default=4, type=int, help="batch size") + parser.add_argument( + "--num_parallel_workers", + default=8, + type=int, + help="num workers for data loading", + ) + parser.add_argument( + "--eval_loss", + default=False, + type=str2bool, + help="whether measure loss including reconstruction, kl, perceptual loss", + ) + parser.add_argument( + "--save_images", + default=True, + type=str2bool, + help="whether save reconstructed images", + ) + parser.add_argument( + "--encode_only", + default=False, + type=str2bool, + help="only encode to save z or distribution", + ) + parser.add_argument( + "--save_z_dist", + default=False, + type=str2bool, + help="If True, save z distribution, mean and logvar. 
Otherwise, save z after sampling.", + ) + + args = parser.parse_args() + + return args + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/magvit/scripts/run_eval_vqvae.sh b/examples/magvit/scripts/run_eval_vqvae.sh new file mode 100644 index 0000000000..4917ae6558 --- /dev/null +++ b/examples/magvit/scripts/run_eval_vqvae.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +export MS_ENABLE_ACLNN=1 +export GRAPH_OP_RUN=1 +export DEVICE_ID=7 + +python scripts/eval_vqvae.py \ + --model_class vqvae-3d \ + --data_path ./datasets/ucf101/test/ \ + --output_path vqvae_eval \ + --size 128 \ + --crop_size 128 \ + --num_frames 17 \ + --frame_stride 1 \ + --batch_size 16 \ + --mode 1 \ + --ckpt_path outputs/vqvae_3d/ckpt/vqvae.ckpt \ + --dtype fp32 \ + --eval_loss True \ diff --git a/examples/magvit/scripts/run_train_vqvae.sh b/examples/magvit/scripts/run_train_vqvae.sh new file mode 100644 index 0000000000..662e40ff15 --- /dev/null +++ b/examples/magvit/scripts/run_train_vqvae.sh @@ -0,0 +1,43 @@ +#!/bin/bash + +export MS_ENABLE_ACLNN=1 +export GRAPH_OP_RUN=1 +export DEVICE_ID=7 + +python scripts/train_vqvae.py \ + --model_class vqvae-3d \ + --pretrained ./model_weights/vqvae2d-lfq-128-init.ckpt \ + --use_discriminator True \ + --use_ema True \ + --dataset_name video \ + --data_path ./datasets/ucf101/train/ \ + --num_frames 17 \ + --frame_stride 1 \ + --size 128 \ + --crop_size 128 \ + --num_parallel_workers 1 \ + --drop_overflow_update True \ + --batch_size 1 \ + --epochs 60 \ + --log_interval 400 \ + --ckpt_save_interval 1 \ + --gradient_accumulation_steps 16 \ + --clip_grad True \ + --max_grad_norm 1.0 \ + --optim adamw \ + --betas 0.99 \ + --weight_decay 0.01 \ + --warmup_steps 1000 \ + --base_learning_rate 2.0e-05 \ + --end_learning_rate 1.0e-07 \ + --scale_lr False \ + --init_loss_scale 1024 \ + --loss_scaler_type dynamic \ + --scale_window 50000 \ + --dtype fp32 \ + --global_bf16 True \ + --mode 0 \ + --debug False \ + --seed 1234 \ + --output_path outputs/vqvae_3d/ + > train_3d.log 2>&1 & diff --git a/examples/magvit/scripts/run_train_vqvae_2d.sh b/examples/magvit/scripts/run_train_vqvae_2d.sh new file mode 100644 index 0000000000..c91b60b3fa --- /dev/null +++ b/examples/magvit/scripts/run_train_vqvae_2d.sh @@ -0,0 +1,39 @@ +#!/bin/bash + +export MS_ENABLE_ACLNN=1 +export GRAPH_OP_RUN=1 +export DEVICE_ID=7 + +python scripts/train_vqvae.py \ + --model_class vqvae-2d \ + --use_discriminator False \ + --use_ema False \ + --dataset_name image \ + --data_path ./datasets/ImageNet/train/ \ + --size 128 \ + --crop_size 128 \ + --num_parallel_workers 1 \ + --drop_overflow_update True \ + --batch_size 1 \ + --epochs 2 \ + --log_interval 400 \ + --ckpt_save_interval 1 \ + --gradient_accumulation_steps 8 \ + --clip_grad True \ + --max_grad_norm 1.0 \ + --optim adamw \ + --betas 0.99 \ + --weight_decay 0.01 \ + --warmup_steps 1000 \ + --base_learning_rate 2.0e-05 \ + --end_learning_rate 1.0e-07 \ + --scale_lr False \ + --init_loss_scale 1024 \ + --loss_scaler_type dynamic \ + --scale_window 1000 \ + --dtype fp32 \ + --global_bf16 True \ + --mode 0 \ + --seed 1234 \ + --output_path outputs/vqvae_2d/ + > train_2d.log 2>&1 & diff --git a/examples/magvit/scripts/run_train_vqvae_2d_parallel.sh b/examples/magvit/scripts/run_train_vqvae_2d_parallel.sh new file mode 100644 index 0000000000..8986704e85 --- /dev/null +++ b/examples/magvit/scripts/run_train_vqvae_2d_parallel.sh @@ -0,0 +1,57 @@ +#!/bin/bash + +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +# improve data loading performance 
for distributed training: 1 +# export MS_ENABLE_NUMA=0 +# plot memory usage, feature/model: 1 +# export MS_MEMORY_STATISTIC=0 + +# export MS_DATASET_SINK_QUEUE=4 + +# operation/graph fusion for dynamic shape +# export MS_DEV_ENABLE_KERNEL_PACKET=on + +# enable kbk : 1 +export MS_ENABLE_ACLNN=1 +export GRAPH_OP_RUN=1 + +# log level +export GLOG_v=2 + +output_dir=outputs/vqvae_2d/ + +msrun --bind_core=True --master_port=8090 --worker_num=8 --local_worker_num=8 --log_dir=$output_dir \ + python scripts/train_vqvae.py \ + --model_class vqvae-2d \ + --use_discriminator False \ + --use_parallel True \ + --use_ema False \ + --dataset_name image \ + --data_path ./datasets/ImageNet/train/ \ + --size 128 \ + --crop_size 128 \ + --num_parallel_workers 1 \ + --drop_overflow_update True \ + --batch_size 1 \ + --epochs 2 \ + --log_interval 400 \ + --ckpt_save_interval 1 \ + --gradient_accumulation_steps 8 \ + --clip_grad True \ + --max_grad_norm 1.0 \ + --optim adamw \ + --betas 0.99 \ + --weight_decay 0.01 \ + --warmup_steps 1000 \ + --base_learning_rate 2.0e-05 \ + --end_learning_rate 1.0e-07 \ + --scale_lr False \ + --init_loss_scale 1024 \ + --loss_scaler_type dynamic \ + --scale_window 1000 \ + --dtype fp32 \ + --global_bf16 True \ + --mode 0 \ + --seed 1234 \ + --output_path outputs/vqvae_2d/ + > train_2d.log 2>&1 & diff --git a/examples/magvit/scripts/run_train_vqvae_parallel.sh b/examples/magvit/scripts/run_train_vqvae_parallel.sh new file mode 100644 index 0000000000..6dc8520f54 --- /dev/null +++ b/examples/magvit/scripts/run_train_vqvae_parallel.sh @@ -0,0 +1,60 @@ +#!/bin/bash + +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +# improve data loading performance for distributed training: 1 +# export MS_ENABLE_NUMA=0 +# plot memory usage, feature/model: 1 +# export MS_MEMORY_STATISTIC=0 + +# export MS_DATASET_SINK_QUEUE=4 + +# operation/graph fusion for dynamic shape +# export MS_DEV_ENABLE_KERNEL_PACKET=on + +# enable kbk : 1 +export MS_ENABLE_ACLNN=1 +export GRAPH_OP_RUN=1 + +# log level +export GLOG_v=2 + +output_dir=outputs/vqvae_3d/ + +msrun --bind_core=True --master_port=8090 --worker_num=8 --local_worker_num=8 --log_dir=$output_dir \ + python scripts/train_vqvae.py \ + --model_class vqvae-3d \ + --pretrained ./model_weights/vqvae2d-lfq-128-init.ckpt \ + --use_discriminator True \ + --use_parallel True \ + --use_ema True \ + --dataset_name video \ + --data_path ./datasets/ucf101/train/ \ + --num_frames 17 \ + --frame_stride 1 \ + --size 128 \ + --crop_size 128 \ + --num_parallel_workers 1 \ + --drop_overflow_update True \ + --batch_size 1 \ + --epochs 60 \ + --log_interval 400 \ + --ckpt_save_interval 1 \ + --gradient_accumulation_steps 16 \ + --clip_grad True \ + --max_grad_norm 1.0 \ + --optim adamw \ + --betas 0.99 \ + --weight_decay 0.01 \ + --warmup_steps 1000 \ + --base_learning_rate 2.0e-05 \ + --end_learning_rate 1.0e-07 \ + --scale_lr False \ + --init_loss_scale 1024 \ + --loss_scaler_type dynamic \ + --scale_window 50000 \ + --dtype fp32 \ + --global_bf16 True \ + --mode 0 \ + --debug False \ + --seed 1234 \ + --output_path outputs/vqvae_3d/ diff --git a/examples/magvit/scripts/train_vqvae.py b/examples/magvit/scripts/train_vqvae.py new file mode 100644 index 0000000000..642e7c17ec --- /dev/null +++ b/examples/magvit/scripts/train_vqvae.py @@ -0,0 +1,460 @@ +"""Trainer for VQVAE""" + +import logging +import os +import shutil +import sys +import time + +import yaml + +import mindspore as ms +from mindspore import Model, nn +from mindspore.nn.wrap.loss_scale import 
DynamicLossScaleUpdateCell +from mindspore.train.callback import TimeMonitor + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../../")) +sys.path.insert(0, mindone_lib_path) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, ".."))) + +from mindcv.optim import create_optimizer +from utils.env import init_env, set_all_reduce_fusion +from videogvt.config.vqgan3d_ucf101_config import get_config +from videogvt.config.vqvae_train_args import parse_args +from videogvt.data.loader import create_dataloader +from videogvt.models.vqvae import StyleGANDiscriminator, build_model +from videogvt.models.vqvae.net_with_loss import DiscriminatorWithLoss, GeneratorWithLoss + +from mindone.trainers.callback import EvalSaveCallback, OverflowMonitor, ProfilerCallback +from mindone.trainers.checkpoint import CheckpointManager, resume_train_network +from mindone.trainers.ema import EMA +from mindone.trainers.lr_schedule import create_scheduler + +# from mindone.trainers.optim import create_optimizer +from mindone.trainers.train_step import TrainOneStepWrapper +from mindone.utils.amp import auto_mixed_precision +from mindone.utils.logger import set_logger +from mindone.utils.params import count_params + +os.environ["HCCL_CONNECT_TIMEOUT"] = "6000" +os.environ["MS_ASCEND_CHECK_OVERFLOW_MODE"] = "INFNAN_MODE" + +logger = logging.getLogger(__name__) + + +def create_loss_scaler(loss_scaler_type, init_loss_scale, loss_scale_factor=2, scale_window=1000): + if args.loss_scaler_type == "dynamic": + loss_scaler = DynamicLossScaleUpdateCell( + loss_scale_value=init_loss_scale, + scale_factor=loss_scale_factor, + scale_window=scale_window, + ) + elif args.loss_scaler_type == "static": + loss_scaler = nn.FixedLossScaleUpdateCell(init_loss_scale) + else: + raise ValueError + + return loss_scaler + + +def main(args): + # 1. init + rank_id, device_num = init_env( + args.mode, + seed=args.seed, + distributed=args.use_parallel, + device_target=args.device_target, + max_device_memory=args.max_device_memory, + parallel_mode=args.parallel_mode, + jit_level=args.jit_level, + global_bf16=args.global_bf16, + debug=args.debug, + ) + + set_logger( + name="", + output_dir=args.output_path, + rank=rank_id, + log_level=eval(args.log_level), + ) + + # 2. build models + # vqvae (G) + model_config = get_config() + dtype = {"fp32": ms.float32, "fp16": ms.float16, "bf16": ms.bfloat16}[args.dtype] + vqvae = build_model(args.model_class, model_config, is_training=True, pretrained=args.pretrained, dtype=dtype) + + # discriminator (D) + use_discriminator = args.use_discriminator and (model_config.lr_configs.disc_weight > 0.0) + + if args.use_discriminator and (model_config.lr_configs.disc_weight <= 0.0): + logging.warning("use_discriminator is True but disc_weight is 0.") + + if use_discriminator: + crop_size = int(args.crop_size) + frame_size = int(args.num_frames) + disc = StyleGANDiscriminator(model_config.discriminator, crop_size, crop_size, frame_size, dtype=dtype) + else: + disc = None + + # mixed precision + # TODO: set softmax, sigmoid computed in FP32. manually set inside network since they are ops, instead of layers whose precision will be set by AMP level. 
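+    # Illustrative sketch (not the executed path): for fp16/bf16 the generator can be
+    # wrapped with mindone's `auto_mixed_precision` helper so that most layers run in the
+    # reduced precision while numerically sensitive cells (e.g. GroupNorm, controlled by
+    # `--vae_keep_gn_fp32`) stay in fp32, for example:
+    #   vqvae = auto_mixed_precision(vqvae, amp_level="O2", dtype=dtype,
+    #                                custom_fp32_cells=[nn.GroupNorm])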
+ if args.dtype not in ["fp16", "bf16"]: + amp_level = "O2" + if not args.global_bf16: + vqvae = auto_mixed_precision( + vqvae, + amp_level=auto_mixed_precision( + vqvae, + amp_level=amp_level, + dtype=dtype, + custom_fp32_cells=[nn.GroupNorm] if args.vae_keep_gn_fp32 else [], + ), + ) + else: + amp_level = "O0" + + # 3. build net with loss (core) + # G with loss + vqvae_with_loss = GeneratorWithLoss( + vqvae, + discriminator=disc, + is_video=(args.dataset_name == "video"), + **model_config.lr_configs, + dtype=dtype, + ) + disc_start = model_config.lr_configs.disc_start + + # D with loss + if use_discriminator: + disc_with_loss = DiscriminatorWithLoss(vqvae, disc, disc_start) + + tot_params, trainable_params = count_params(vqvae_with_loss) + logger.info("Total params {:,}; Trainable params {:,}".format(tot_params, trainable_params)) + + # 4. build dataset + ds_config = dict( + csv_path=args.csv_path, + data_folder=args.data_path, + size=args.size, + crop_size=args.crop_size, + random_crop=args.random_crop, + flip=args.flip, + ) + if args.dataset_name == "video": + ds_config.update( + dict( + sample_stride=args.frame_stride, + sample_n_frames=args.num_frames, + return_image=False, + ) + ) + assert not ( + # model_config.generator.params.ddconfig.split_time_upsample + args.num_frames % 2 == 0 + and False + ), "num of frames must be odd if split_time_upsample is True" + else: + ds_config.update(dict(expand_dim_t=args.expand_dim_t)) + dataset = create_dataloader( + ds_config=ds_config, + batch_size=args.batch_size, + ds_name=args.dataset_name, + num_parallel_workers=args.num_parallel_workers, + shuffle=args.shuffle, + device_num=device_num, + rank_id=rank_id, + ) + dataset_size = dataset.get_dataset_size() + + # 5. build training utils + # torch scale lr by: model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr + if args.scale_lr: + learning_rate = args.base_learning_rate * args.batch_size * args.gradient_accumulation_steps * device_num + else: + learning_rate = args.base_learning_rate + + total_train_steps = args.epochs * dataset_size + + if not args.decay_steps: + args.decay_steps = max(1, total_train_steps - args.warmup_steps) + + lr = create_scheduler( + steps_per_epoch=dataset_size, + name=args.scheduler, + lr=learning_rate, + end_lr=args.end_learning_rate, + warmup_steps=args.warmup_steps, + decay_steps=args.decay_steps, + num_epochs=args.epochs, + ) + + set_all_reduce_fusion( + vqvae_with_loss.trainable_params(), + split_num=7, + distributed=args.use_parallel, + parallel_mode=args.parallel_mode, + ) + + # build optimizer + update_logvar = False # in torch, vqvae_with_loss.logvar is not updated. 
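+    # Note on the learning rate consumed by the optimizers below: when `--scale_lr` is
+    # enabled, the base rate is scaled by batch_size * gradient_accumulation_steps *
+    # device_num (see above). For example, with the 8-NPU parallel script defaults this
+    # gives 2.0e-05 * 1 * 16 * 8 = 2.56e-03. The provided scripts set `--scale_lr False`,
+    # so `base_learning_rate` is used as-is.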
+ if update_logvar: + vqvae_params_to_update = [ + vqvae_with_loss.vqvae.trainable_params(), + vqvae_with_loss.logvar, + ] + else: + vqvae_params_to_update = vqvae_with_loss.vqvae.trainable_params() + + optim_vqvae = create_optimizer( + vqvae_params_to_update, + opt=args.optim, + weight_decay=args.weight_decay, + lr=lr, + eps=1e-08, + beta1=0.9, + beta2=0.999, + weight_decay_filter="norm_and_bias", + ) + + loss_scaler_vqvae = create_loss_scaler( + args.loss_scaler_type, + args.init_loss_scale, + args.loss_scale_factor, + args.scale_window, + ) + + if use_discriminator: + optim_disc = create_optimizer( + disc_with_loss.discriminator.trainable_params(), + opt=args.optim, + weight_decay=args.weight_decay, + lr=lr, + eps=1e-08, + beta1=0.9, + beta2=0.999, + weight_decay_filter="norm_and_bias", + ) + + loss_scaler_disc = create_loss_scaler( + args.loss_scaler_type, + args.init_loss_scale, + args.loss_scale_factor, + args.scale_window, + ) + + ema = ( + EMA( + vqvae_with_loss.vqvae, + ema_decay=args.ema_decay, + offloading=False, + ).to_float(dtype) + if args.use_ema + else None + ) + + # resume training states + # TODO: resume Discriminator if used + ckpt_dir = os.path.join(args.output_path, "ckpt") + os.makedirs(ckpt_dir, exist_ok=True) + start_epoch = 0 + if args.resume: + resume_ckpt = os.path.join(ckpt_dir, "train_resume.ckpt") if isinstance(args.resume, bool) else args.resume + + start_epoch, loss_scale, cur_iter, last_overflow_iter = resume_train_network( + vqvae_with_loss, optim_vqvae, resume_ckpt + ) + loss_scaler_vqvae.loss_scale_value = loss_scale + loss_scaler_vqvae.cur_iter = cur_iter + loss_scaler_vqvae.last_overflow_iter = last_overflow_iter + logger.info(f"Resume training from {resume_ckpt}") + + # training step + training_step_vqvae = TrainOneStepWrapper( + vqvae_with_loss, + optimizer=optim_vqvae, + scale_sense=loss_scaler_vqvae, + drop_overflow_update=args.drop_overflow_update, + gradient_accumulation_steps=args.gradient_accumulation_steps, + clip_grad=args.clip_grad, + clip_norm=args.max_grad_norm, + ema=ema, + ) + + if use_discriminator: + training_step_disc = TrainOneStepWrapper( + disc_with_loss, + optimizer=optim_disc, + scale_sense=loss_scaler_disc, + drop_overflow_update=args.drop_overflow_update, + gradient_accumulation_steps=args.gradient_accumulation_steps, + clip_grad=args.clip_grad, + clip_norm=args.max_grad_norm, + ema=None, # No ema for disriminator + ) + + if rank_id == 0: + key_info = "Key Settings:\n" + "=" * 50 + "\n" + key_info += "\n".join( + [ + f"MindSpore mode[GRAPH(0)/PYNATIVE(1)]: {args.mode}", + f"Distributed mode: {args.use_parallel}", + f"amp level: {amp_level}", + f"dtype: {args.dtype}", + f"Data path: {args.data_path}", + f"Learning rate: {learning_rate}", + f"Batch size: {args.batch_size}", + f"Rescale size: {args.size}", + f"Crop size: {args.crop_size}", + f"Weight decay: {args.weight_decay}", + f"Grad accumulation steps: {args.gradient_accumulation_steps}", + f"Num epochs: {args.epochs}", + f"Loss scaler: {args.loss_scaler_type}", + f"Init loss scale: {args.init_loss_scale}", + f"Grad clipping: {args.clip_grad}", + f"Max grad norm: {args.max_grad_norm}", + f"EMA: {args.use_ema}", + ] + ) + key_info += "\n" + "=" * 50 + logger.info(key_info) + + # 6. 
training process + # backup config files + args.config = "videogvt/config/vqgan3d_ucf101_config.py" + shutil.copyfile(args.config, os.path.join(args.output_path, os.path.basename(args.config))) + with open(os.path.join(args.output_path, "args.yaml"), "w") as f: + yaml.safe_dump(vars(args), stream=f, default_flow_style=False, sort_keys=False) + + if not use_discriminator: + if args.global_bf16: + model = Model(training_step_vqvae, amp_level="O0") + else: + model = Model(training_step_vqvae) + + # callbacks + callback = [TimeMonitor(args.log_interval)] + ofm_cb = OverflowMonitor() + callback.append(ofm_cb) + + if rank_id == 0: + save_cb = EvalSaveCallback( + network=vqvae_with_loss.vqvae, + rank_id=rank_id, + ckpt_save_dir=ckpt_dir, + ema=ema, + ckpt_save_policy="latest_k", + ckpt_max_keep=args.ckpt_max_keep, + ckpt_save_interval=args.ckpt_save_interval, + log_interval=args.log_interval, + start_epoch=start_epoch, + model_name="vqvae_3d", + record_lr=False, + ) + callback.append(save_cb) + if args.profile: + callback.append(ProfilerCallback()) + + logger.info("Start training...") + + model.train( + args.epochs, + dataset, + callbacks=callback, + dataset_sink_mode=args.dataset_sink_mode, + # sink_size=args.sink_size, + initial_epoch=start_epoch, + ) + + else: + if rank_id == 0: + ckpt_manager = CheckpointManager(ckpt_dir, "latest_k", k=args.ckpt_max_keep) + + # output_numpy=True ? + ds_iter = dataset.create_dict_iterator(args.epochs - start_epoch) + bp_steps = 0 + + logger.info("Start training...") + for epoch in range(start_epoch, args.epochs): + epoch_loss = 0.0 + avg_loss = 0.0 + start_time_e = time.time() + + for step, data in enumerate(ds_iter): + start_time_s = time.time() + x = data[args.dataset_name].to(dtype) + + cur_global_step = epoch * dataset_size + step + 1 + + # NOTE: inputs must match the order in GeneratorWithLoss.construct + loss_vqvae_t, overflow, scaling_sens = training_step_vqvae(x) + loss_disc_t, overflow_d, scaling_sens_d = training_step_disc(x) + + if overflow: + logger.warning(f"Overflow occurs in step {cur_global_step}") + + # loss + loss = float(loss_vqvae_t.asnumpy()) + float(loss_disc_t.asnumpy()) + avg_loss += loss + epoch_loss += loss + + # log + step_time = time.time() - start_time_s + if (step + 1) % args.log_interval == 0: + avg_loss /= float(args.log_interval) + logger.info( + f"E: {epoch+1}, S: {step+1}, Loss vqvae avg: {avg_loss:.4f}, Step time: {step_time*1000:.2f}ms" + ) + avg_loss = 0.0 + bp_steps += 1 + + if rank_id == 0 and args.step_mode: + cur_epoch = epoch + 1 + if (cur_global_step % args.ckpt_save_interval == 0) or (cur_global_step == total_train_steps): + ckpt_name = f"vqvae-s{cur_global_step}.ckpt" + if ema is not None: + ema.swap_before_eval() + vqvae_with_loss.set_train(False) + disc_with_loss.set_train(False) + ckpt_manager.save(vqvae_with_loss.vqvae, None, ckpt_name=ckpt_name, append_dict=None) + + if ema is not None: + ema.swap_after_eval() + vqvae_with_loss.set_train(True) + disc_with_loss.set_train(True) + + if cur_global_step == total_train_steps: + break + + epoch_cost = time.time() - start_time_e + per_step_time = epoch_cost / dataset_size + cur_epoch = epoch + 1 + epoch_loss /= dataset_size + logger.info( + f"Epoch:[{int(cur_epoch):>3d}/{int(args.epochs):>3d}], loss avg: {epoch_loss:.4f}," + f"epoch time:{epoch_cost:.2f}s, per step time:{per_step_time*1000:.2f}ms, " + ) + + if rank_id == 0 and args.step_mode: + cur_epoch = epoch + 1 + if (cur_global_step % args.ckpt_save_interval == 0) or (cur_global_step == total_train_steps): + 
ckpt_name = f"vqvae-e{cur_epoch}.ckpt" + if ema is not None: + ema.swap_before_eval() + vqvae_with_loss.set_train(False) + disc_with_loss.set_train(False) + ckpt_manager.save(vqvae_with_loss.vqvae, None, ckpt_name=ckpt_name, append_dict=None) + + if ema is not None: + ema.swap_after_eval() + vqvae_with_loss.set_train(True) + disc_with_loss.set_train(True) + + if cur_global_step == total_train_steps: + break + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/magvit/tools/datasets.md b/examples/magvit/tools/datasets.md new file mode 100644 index 0000000000..1a3dbf9943 --- /dev/null +++ b/examples/magvit/tools/datasets.md @@ -0,0 +1,51 @@ +# Datasets + +Here we present the download and detailed preprocessing tutorial for the data we used in our training. + + +## Image Dataset for pretraining + +Following the original paper, we use [ImageNet-1K](https://www.image-net.org/challenges/LSVRC/index.php) to pretrain VQVAE-2D as the initialiation. + +| Dataset | Train | Val | +| --- | --- | --- | +| ImageNet-1K | 1281167 | 50000 | + +### Download +You can download through the link on [Datasets: ILSVRC/imagenet-1k](https://huggingface.co/datasets/ILSVRC/imagenet-1k). + +### Preprocessing +After downloading, please unzip the folders and rearrange them into train and validation folders: + +``` +├─ImageNet +│ ├─train +| | ├─n01440764 +| | | ├─n01440764_10026.JPEG +| | | ├─n01440764_10027.JPEG +| │ | └─ ... +| | ├─n01443537 +| | | ├─n01443537_10026.JPEG +| | | ├─n01443537_10027.JPEG +| │ | └─ ... +│ └─validation +| ├─val_00000001.JPEG +| ├─val_00000002.JPEG +| └─ ... +``` + +## Video Dataset + +In this repositry, we use [UCF-101](https://www.crcv.ucf.edu/data/UCF101.php) to train the VQVAE-3d. + +We use the Train/Test Splits for Action Recognition, the statistics are: + +| Dataset | Train | Test | +| --- | --- | --- | +| UCF-101| 9537 | 3783 | + +### Download +You can download the dataset and *The Train/Test Splits for Action Recognition* on the page of [UCF-101](https://www.crcv.ucf.edu/data/UCF101.php) + +### Preprocessing +After downloading, please split the dataset into train and test according to the Splits text file. We also provide a script for the seperation, you can refer to [ucf101.py](./ucf101.py) diff --git a/examples/magvit/tools/inflate_vae2d_to_3d.py b/examples/magvit/tools/inflate_vae2d_to_3d.py new file mode 100644 index 0000000000..d4d3805394 --- /dev/null +++ b/examples/magvit/tools/inflate_vae2d_to_3d.py @@ -0,0 +1,98 @@ +import argparse +import os +import sys + +import mindspore as ms +from mindspore import context + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, ".."))) + +from videogvt.config.vqgan3d_ucf101_config import get_config +from videogvt.models.vqvae import build_model + +context.set_context(mode=1, device_target="Ascend", device_id=7) + + +def inflate(vae2d_ckpt, save_fp): + model_config = get_config() + dtype = ms.float32 + vae3d = build_model("vqvae-3d", model_config, is_training=False, dtype=dtype) + vae2d = ms.load_checkpoint(vae2d_ckpt) + + vae_2d_keys = list(vae2d.keys()) + vae_3d_keys = list(vae3d.parameters_dict().keys()) + + # 3d -> 2d + map_dict = { + "conv.weight": "weight", + "conv.bias": "bias", + } + + new_state_dict = {} + + for key_3d in vae_3d_keys: + if key_3d.startswith("loss"): + continue + + # param name mapping from vae-3d to vae-2d + key_2d = "vqvae." 
+ key_3d + # if not 'attn' in key_2d: + for kw in map_dict: + key_2d = key_2d.replace(kw, map_dict[kw]) + + assert key_2d in vae_2d_keys, f"Key {key_2d} ({key_3d}) not found in 2D VAE" + + # set vae 3d state dict + shape_3d = vae3d.parameters_dict()[key_3d].shape + shape_2d = vae2d[key_2d].shape + if "bias" in key_2d: + assert shape_3d == shape_2d, f"Shape mismatch for key {key_3d} ({key_2d})" + new_state_dict[key_3d] = vae2d[key_2d] + elif "norm" in key_2d: + assert shape_3d == shape_2d, f"Shape mismatch for key {key_3d} ({key_2d})" + new_state_dict[key_3d] = vae2d[key_2d] + elif "conv" in key_2d or "nin_shortcut" in key_2d: + if shape_3d[:2] != shape_2d[:2]: + print(key_2d, shape_3d, shape_2d) + + weights = vae2d[key_2d] + shape_3d_new = tuple([shape_3d[0] // 2]) + shape_3d[1:] + new_w = ms.ops.zeros(shape_3d_new, dtype=weights.dtype) + new_w[:, :, -1, :, :] = weights + new_w = new_w.repeat(2, 0) + + new_w = ms.Parameter(new_w, name=key_3d) + new_state_dict[key_3d] = new_w + + else: + w = vae2d[key_2d] + new_w = ms.ops.zeros(shape_3d, dtype=w.dtype) + # tail initialization + new_w[:, :, -1, :, :] = w # cin, cout, t, h, w + + new_w = ms.Parameter(new_w, name=key_3d) + new_state_dict[key_3d] = new_w + + elif "attn_1" in key_2d: + new_val = vae2d[key_2d].expand_dims(axis=2) + new_param = ms.Parameter(new_val, name=key_3d) + new_state_dict[key_3d] = new_param + else: + raise NotImplementedError(f"Key {key_3d} ({key_2d}) not implemented") + + ms.save_checkpoint(new_state_dict, save_fp) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--src", type=str, default=None, help="path to mindspore vae 2d checkpoint") + parser.add_argument( + "--target", + type=str, + default="models/causal_vae_488_init.ckpt", + help="target file path to save the inflated checkpoint", + ) + args = parser.parse_args() + + inflate(args.src, args.target) diff --git a/examples/magvit/tools/ucf101.py b/examples/magvit/tools/ucf101.py new file mode 100644 index 0000000000..90100a9eb4 --- /dev/null +++ b/examples/magvit/tools/ucf101.py @@ -0,0 +1,41 @@ +import shutil + + +def read_video_list(fn, is_train=False): + video_list = [] + with open(fn, "r") as f: + lines = f.read().split("\n") + for line in lines: + if line == "": + continue + if is_train: + v_name = line.split(" ")[0].strip() + else: + v_name = line.strip() + video_list.append(v_name) + + return video_list + + +fn_train = "ucfTrainTestlist/trainlist01.txt" +fn_test = "ucfTrainTestlist/testlist01.txt" + +# train +train_list = read_video_list(fn_train, True) +dir_src = "ucf101/fullset/UCF-101/" +dir_des = "ucf101/rec_train/" +for vname in train_list: + path_src = dir_src + vname + path_des = dir_des + vname.split("/")[-1] + shutil.move(path_src, path_des) + print(f"Moved {path_src} to {path_des}.") + +# test +test_list = read_video_list(fn_test, False) +dir_src = "ucf101/fullset/UCF-101/" +dir_des = "ucf101/rec_test/" +for vname in test_list: + path_src = dir_src + vname + path_des = dir_des + vname.split("/")[-1] + shutil.move(path_src, path_des) + print(f"Moved {path_src} to {path_des}.") diff --git a/examples/magvit/utils/env.py b/examples/magvit/utils/env.py new file mode 100644 index 0000000000..3a8774d567 --- /dev/null +++ b/examples/magvit/utils/env.py @@ -0,0 +1,115 @@ +import logging +from typing import Tuple + +import mindspore as ms +from mindspore.communication.management import get_group_size, get_rank, init + +from mindone.utils.seed import set_random_seed + +logger = logging.getLogger(__name__) + + +def 
init_env( + mode: int = ms.GRAPH_MODE, + seed: int = 42, + distributed: bool = False, + max_device_memory: str = None, + device_target: str = "Ascend", + parallel_mode: str = "data", + jit_level: str = "O0", + global_bf16: bool = False, + debug: bool = False, +) -> Tuple[int, int]: + """ + Initialize MindSpore environment. + + Args: + mode: MindSpore execution mode. Default is 0 (ms.GRAPH_MODE). + seed: The seed value for reproducibility. Default is 42. + distributed: Whether to enable distributed training. Default is False. + Returns: + A tuple containing the device ID, rank ID and number of devices. + """ + set_random_seed(seed) + + if debug and mode == ms.GRAPH_MODE: # force PyNative mode when debugging + logger.warning("Debug mode is on, switching execution mode to PyNative.") + mode = ms.PYNATIVE_MODE + + if max_device_memory is not None: + ms.set_context(max_device_memory=max_device_memory) + + if distributed: + ms.set_context( + mode=mode, + device_target=device_target, + ) + if parallel_mode == "optim": + print("use optim parallel") + ms.set_auto_parallel_context( + parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, + enable_parallel_optimizer=True, + ) + init() + device_num = get_group_size() + rank_id = get_rank() + else: + init() + device_num = get_group_size() + rank_id = get_rank() + logger.debug(f"rank_id: {rank_id}, device_num: {device_num}") + ms.reset_auto_parallel_context() + + ms.set_auto_parallel_context( + parallel_mode=ms.ParallelMode.DATA_PARALLEL, + gradients_mean=True, + device_num=device_num, + ) + + var_info = ["device_num", "rank_id", "device_num / 8", "rank_id / 8"] + var_value = [device_num, rank_id, int(device_num / 8), int(rank_id / 8)] + logger.info(dict(zip(var_info, var_value))) + + else: + device_num = 1 + rank_id = 0 + ms.set_context( + mode=mode, + device_target=device_target, + pynative_synchronize=debug, + ) + + try: + if jit_level in ["O0", "O1", "O2"]: + ms.set_context(jit_config={"jit_level": jit_level}) + else: + logger.warning( + f"Unsupported jit_level: {jit_level}. The framework will automatically select the execution mode." + ) + except Exception: + logger.warning( + "The current jit_level is not suitable because current MindSpore version or mode does not match," + "please ensure the MindSpore version >= ms2.3_0615, and use GRAPH_MODE." + ) + + if global_bf16: + ms.set_context(ascend_config={"precision_mode": "allow_mix_precision_bf16"}) + + return rank_id, device_num + + +def set_all_reduce_fusion( + params, + split_num: int = 7, + distributed: bool = False, + parallel_mode: str = "data", +) -> None: + """Set allreduce fusion strategy by split_num.""" + + if distributed and parallel_mode == "data": + all_params_num = len(params) + step = all_params_num // split_num + split_list = [i * step for i in range(1, split_num)] + split_list.append(all_params_num - 1) + logger.info(f"Distribute config set: dall_params_num: {all_params_num}, set all_reduce_fusion: {split_list}") + ms.set_auto_parallel_context(all_reduce_fusion_config=split_list) diff --git a/examples/magvit/videogvt/config/vqgan3d_ucf101_config.py b/examples/magvit/videogvt/config/vqgan3d_ucf101_config.py new file mode 100644 index 0000000000..e72197f325 --- /dev/null +++ b/examples/magvit/videogvt/config/vqgan3d_ucf101_config.py @@ -0,0 +1,64 @@ +r"""Configs for the VQGAN-3D on the UCF101. 
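+This config groups the VQVAE-3D architecture settings, the StyleGAN discriminator
+settings, the loss weights, and the LFQ quantizer options consumed by
+`scripts/train_vqvae.py`.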
+ +""" + +import ml_collections + + +def get_config(): + """Returns the base experiment configuration.""" + + config = ml_collections.ConfigDict() + + # Model: vqvae + config.vqvae = ml_collections.ConfigDict() + config.vqvae.channels = 3 + config.vqvae.embedding_dim = 18 + config.vqvae.codebook_size = 262144 # 2^18 + config.vqvae.filters = 128 + config.vqvae.activation_fn = "swish" + config.vqvae.num_enc_res_blocks = 4 + config.vqvae.num_dec_res_blocks = 4 + config.vqvae.channel_multipliers = (1, 2, 2, 4) + config.vqvae.spatial_downsample = (True, True, True) + config.vqvae.temporal_downsample = (True, True, False) + config.vqvae.num_groups = 32 + config.vqvae.num_frames = 17 + + config.discriminator = ml_collections.ConfigDict() + config.discriminator.filters = config.vqvae.get_oneway_ref("filters") + config.discriminator.channel_multipliers = (2, 4, 4, 4, 4) + + # Loss + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.perceptual_weight = 0.1 + config.lr_configs.entropy_weight = 0.1 + config.lr_configs.commit_weight = 0.25 + config.lr_configs.recons_weight = 10.0 + config.lr_configs.disc_weight = 0.1 + config.lr_configs.disc_start = 1 + + # LFQ + config.lfq = ml_collections.ConfigDict() + config.lfq.dim = config.vqvae.embedding_dim + config.lfq.codebook_size = config.vqvae.codebook_size + config.lfq.entropy_loss_weight = config.lr_configs.entropy_weight + config.lfq.commitment_loss_weight = config.lr_configs.commit_weight + config.lfq.diversity_gamma = 1.0 + config.lfq.straight_through_activation = "identity" + config.lfq.num_codebooks = 1 + config.lfq.keep_num_codebooks_dim = None + config.lfq.codebook_scale = 1.0 # for residual LFQ, codebook scaled down by 2x at each layer + config.lfq.frac_per_sample_entropy = ( + 1.0 # make less than 1. to only use a random fraction of the probs for per sample entropy + ) + config.lfq.inv_temperature = 100.0 + config.lfq.soft_clamp_input_value = None + config.lfq.cosine_sim_project_in = False + config.lfq.cosine_sim_project_in_scale = None + + # Pretrained models on ImageNet. 
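+    # Quick sanity check of the settings above: codebook_size equals 2 ** embedding_dim
+    # (2 ** 18 = 262144), and with three spatial and two temporal downsample stages a
+    # 17 x 128 x 128 clip is encoded into a 5 x 16 x 16 token grid, matching the numbers
+    # reported in the README.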
+ config.init_from = ml_collections.ConfigDict() + config.init_from.inflation = "2d->3d" + + return config diff --git a/examples/magvit/videogvt/config/vqvae_train_args.py b/examples/magvit/videogvt/config/vqvae_train_args.py new file mode 100644 index 0000000000..ee18dfc057 --- /dev/null +++ b/examples/magvit/videogvt/config/vqvae_train_args.py @@ -0,0 +1,264 @@ +import argparse +import logging +import os + +import yaml + +from mindone.utils.config import str2bool + +logger = logging.getLogger() + + +def _check_cfgs_in_parser(cfgs: dict, parser: argparse.ArgumentParser): + actions_dest = [action.dest for action in parser._actions] + defaults_key = parser._defaults.keys() + cfg_keys = list(cfgs.keys()) + for k in cfg_keys: + if k not in actions_dest and k not in defaults_key: + raise KeyError(f"{k} does not exist in ArgumentParser!") + cfgs.pop(k) + return cfgs + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--config", + default="", + type=str, + help="path to load a config yaml file that describes the training recipes which will override the default arguments", + ) + parser.add_argument( + "--model_class", + default="vqvae-3d", + type=str, + choices=[ + "vqvae-2d", + "vqvae-3d", + ], + help="model arch type", + ) + parser.add_argument( + "--pretrained", + default=None, + type=str, + help="path to pretrained autoencoder checkpoint", + ) + parser.add_argument("--use_parallel", default=False, type=str2bool, help="use parallel") + parser.add_argument( + "--parallel_mode", default="data", type=str, choices=["data", "optim"], help="parallel mode: data/optim" + ) + parser.add_argument("--debug", default=False, type=str2bool, help="debug mode") + parser.add_argument( + "--output_path", + default="outputs/vae_train", + type=str, + help="output directory to save training results", + ) + parser.add_argument( + "--resume", + default=False, + type=str, + help="resume training, can set True or path to resume checkpoint.(default=False)", + ) + + # ms + parser.add_argument( + "--mode", + default=0, + type=int, + help="Specify the mode: 0 for graph mode, 1 for pynative mode", + ) + parser.add_argument("--device_target", type=str, default="Ascend", help="Ascend or GPU") + parser.add_argument("--max_device_memory", type=str, default=None, help="e.g. `30GB` for 910a, `59GB` for 910b") + parser.add_argument("--profile", default=False, type=str2bool, help="Profile or not") + parser.add_argument( + "--jit_level", + default="O0", + type=str, + choices=["O0", "O1", "O2"], + help="Used to control the compilation optimization level. Supports [“O0”, “O1”, “O2”]." + "O0: Except for optimizations that may affect functionality, all other optimizations are turned off, adopt KernelByKernel execution mode." + "O1: Using commonly used optimizations and automatic operator fusion optimizations, adopt KernelByKernel execution mode." + "O2: Ultimate performance optimization, adopt Sink execution mode.", + ) + parser.add_argument( + "--vae_keep_gn_fp32", + default=True, + type=str2bool, + help="whether keep GroupNorm in fp32.", + ) + parser.add_argument( + "--global_bf16", + default=False, + type=str2bool, + help="Experimental. 
If True, dtype will be overrided, operators will be computered in bf16 if they are supported by CANN", + ) + # data + parser.add_argument( + "--dataset_name", + default="image", + type=str, + choices=["image", "video"], + help="dataset name, image or video", + ) + parser.add_argument("--data_path", default="dataset", type=str, help="data path") + parser.add_argument("--csv_path", default=None, type=str, help="path to csv annotation file") + parser.add_argument("--dataset_sink_mode", default=False, type=str2bool, help="sink mode") + parser.add_argument("--shuffle", default=True, type=str2bool, help="data shuffle") + parser.add_argument( + "--num_parallel_workers", + default=8, + type=int, + help="num workers for data loading", + ) + parser.add_argument("--size", default=384, type=int, help="image rescale size") + parser.add_argument("--crop_size", default=256, type=int, help="image crop size") + parser.add_argument( + "--random_crop", + default=False, + type=str2bool, + help="random crop for data augmentation", + ) + parser.add_argument( + "--flip", + default=False, + type=str2bool, + help="horizontal flip for data augmentation", + ) + parser.add_argument( + "--expand_dim_t", + default=False, + type=str2bool, + help="expand temporal axis for image data, used for vae 3d training with image data", + ) + parser.add_argument("--num_frames", default=17, type=int, help="num frames") + parser.add_argument("--frame_stride", default=1, type=int, help="frame sampling stride") + + # optim + parser.add_argument( + "--use_discriminator", + default=False, + type=str2bool, + help="Phase 1 training does not use discriminator, set False to reduce memory cost in graph mode.", + ) + parser.add_argument( + "--dtype", + default="fp32", + type=str, + choices=["fp32", "fp16", "bf16"], + help="mixed precision type, if fp32, all layer precision is float32 (amp_level=O0), \ + if bf16 or fp16, amp_level==O2, part of layers will compute in bf16 or fp16 such as matmul, dense, conv.", + ) + parser.add_argument("--optim", default="adam", type=str, help="optimizer") + parser.add_argument( + "--betas", + type=float, + default=(0.5, 0.9), # [0.9, 0.999] + help="Specify the [beta1, beta2] parameter for the Adam or AdamW optimizer.", + ) + parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay.") + parser.add_argument( + "--group_strategy", + type=str, + default="norm_and_bias", + help="Grouping strategy for weight decay. If `norm_and_bias`, weight decay filter list is [beta, gamma, bias]. \ + If None, filter list is [layernorm, bias]. Default: norm_and_bias", + ) + parser.add_argument("--seed", default=3407, type=int, help="random seed") + parser.add_argument("--warmup_steps", default=1000, type=int, help="warmup steps") + parser.add_argument("--batch_size", default=4, type=int, help="batch size") + parser.add_argument("--log_interval", default=1, type=int, help="log interval") + parser.add_argument( + "--base_learning_rate", + default=4.5e-06, + type=float, + help="base learning rate, can be scaled by global batch size", + ) + parser.add_argument( + "--end_learning_rate", + default=1e-8, + type=float, + help="The end learning rate for Adam.", + ) + parser.add_argument( + "--scale_lr", + default=True, + type=str2bool, + help="scale base-lr by ngpu * batch_size * n_accumulate", + ) + parser.add_argument("--decay_steps", default=0, type=int, help="lr decay steps.") + parser.add_argument( + "--scheduler", + default="cosine_decay", + type=str, + help="scheduler. 
option: constant, cosine_decay, ", + ) + parser.add_argument("--epochs", default=10, type=int, help="epochs") + parser.add_argument("--loss_scaler_type", default="static", type=str, help="dynamic or static") + parser.add_argument("--init_loss_scale", default=1024, type=float, help="loss scale") + parser.add_argument("--loss_scale_factor", default=2, type=float, help="loss scale factor") + parser.add_argument("--scale_window", default=1000, type=float, help="scale window") + parser.add_argument( + "--gradient_accumulation_steps", + default=1, + type=int, + help="gradient accumulation steps", + ) + parser.add_argument("--use_ema", default=False, type=str2bool, help="whether use EMA") + parser.add_argument("--ema_decay", default=0.9999, type=float, help="EMA decay") + parser.add_argument( + "--clip_grad", + default=False, + type=str2bool, + help="whether apply gradient clipping", + ) + parser.add_argument( + "--drop_overflow_update", + default=True, + type=str2bool, + help="drop overflow update", + ) + parser.add_argument( + "--max_grad_norm", + default=1.0, + type=float, + help="max gradient norm for clipping, effective when `clip_grad` enabled.", + ) + parser.add_argument( + "--ckpt_max_keep", + default=10, + type=int, + help="Maximum number of checkpoints to keep", + ) + parser.add_argument( + "--ckpt_save_interval", + default=1, + type=int, + help="save checkpoint every this epochs or steps", + ) + parser.add_argument( + "--step_mode", + default=False, + type=str2bool, + help="whether save ckpt by steps. If False, save ckpt by epochs.", + ) + parser.add_argument( + "--log_level", + type=str, + default="logging.INFO", + help="log level, options: logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR", + ) + + abs_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "")) + default_args = parser.parse_args() + if default_args.config: + default_args.config = os.path.join(abs_path, "../../", default_args.config) + with open(default_args.config, "r") as f: + cfg = yaml.safe_load(f) + cfg = _check_cfgs_in_parser(cfg, parser) + parser.set_defaults(**cfg) + args = parser.parse_args() + + logger.info(args) + return args diff --git a/examples/magvit/videogvt/data/image_dataset.py b/examples/magvit/videogvt/data/image_dataset.py new file mode 100644 index 0000000000..77e3ce8628 --- /dev/null +++ b/examples/magvit/videogvt/data/image_dataset.py @@ -0,0 +1,205 @@ +import copy +import csv +import glob +import logging +import os + +import albumentations +import cv2 +import numpy as np +from PIL import Image + +import mindspore as ms + +logger = logging.getLogger() + + +def create_image_transforms( + size=384, + crop_size=256, + interpolation="bicubic", + backend="al", + random_crop=False, + flip=False, +): + if backend == "pt": + from torchvision import transforms + from torchvision.transforms.functional import InterpolationMode + + mapping = { + "bilinear": InterpolationMode.BILINEAR, + "bicubic": InterpolationMode.BICUBIC, + } + + pixel_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=mapping[interpolation]), + transforms.CenterCrop((crop_size, crop_size)), + transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + ] + ) + else: + # expect rgb image in range 0-255, shape (h w c) + from albumentations import CenterCrop, HorizontalFlip, Normalize, RandomCrop, SmallestMaxSize + + mapping = {"bilinear": cv2.INTER_LINEAR, "bicubic": cv2.INTER_CUBIC} + transforms = [ + SmallestMaxSize(max_size=size, interpolation=mapping[interpolation]), + 
(CenterCrop(crop_size, crop_size) if not random_crop else RandomCrop(crop_size, crop_size)), + Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + ] + if flip: + transforms += [HorizontalFlip(p=0.5)] + + pixel_transforms = albumentations.Compose(transforms) + + return pixel_transforms + + +def get_image_path_list(folder): + # TODO: find recursively + fmts = ["jpg", "png", "jpeg", "JPEG"] + out = [] + for fmt in fmts: + out += glob.glob(os.path.join(folder, f"*.{fmt}")) + if len(out) == 0: + for fmt in fmts: + out += glob.glob(os.path.join(folder, f"*/*.{fmt}")) + return sorted(out) + + +class ImageDataset: + def __init__( + self, + csv_path=None, + data_folder=None, + size=384, + crop_size=256, + random_crop=False, + flip=False, + image_column="file_name", + expand_dim_t=False, + **kwargs, + ): + if csv_path is not None: + with open(csv_path, "r") as csvfile: + self.dataset = list(csv.DictReader(csvfile)) + self.read_from_csv = True + else: + self.dataset = get_image_path_list(data_folder) + self.read_from_csv = False + self.length = len(self.dataset) + + logger.info(f"Num data samples: {self.length}") + + self.data_folder = data_folder + + self.transform_backend = "al" # pt, al + self.pixel_transforms = create_image_transforms( + size, + crop_size, + random_crop=random_crop, + flip=flip, + backend=self.transform_backend, + ) + self.image_column = image_column + self.expand_dim_t = expand_dim_t + + # prepare replacement data + # max_attempts = 100 + # self.prev_ok_sample = self.get_replace_data(max_attempts) + # self.require_update_prev = False + + def get_replace_data(self, max_attempts=100): + replace_data = None + attempts = min(max_attempts, self.length) + for idx in range(attempts): + try: + pixel_values, caption = self.read_sample(idx) + replace_data = copy.deepcopy((pixel_values, caption)) + break + except Exception as e: + print("\tError msg: {}".format(e), flush=True) + + assert replace_data is not None, f"Fail to preload sample in {attempts} attempts." + + return replace_data + + def read_sample(self, idx): + if self.read_from_csv: + image_dict = self.dataset[idx] + # first column is image path + image_fn = image_dict[list(image_dict.keys())[0]] + image_path = os.path.join(self.data_folder, image_fn) + else: + image_path = self.dataset[idx] + + image = Image.open(image_path).convert("RGB") + image = np.array(image) + + return image + + def __len__(self): + return self.length + + def __getitem__(self, idx): + # try: + image = self.read_sample(idx) + """ + if (self.prev_ok_sample is None) or (self.require_update_prev): + self.prev_ok_sample = copy.deepcopy(image) + self.require_update_prev = False + except Exception as e: + logger.warning(f"Fail to get sample of idx {idx}. 
The corrupted video will be replaced.") + print("\tError msg: {}".format(e), flush=True) + assert self.prev_ok_sample is not None + image = self.prev_ok_sample # unless the first sample is already not ok + self.require_update_prev = True + + if idx >= self.length: + raise IndexError # needed for checking the end of dataset iteration + """ + + if self.transform_backend == "pt": + import torch + + pixel_values = torch.from_numpy(image).permute(2, 0, 1).contiguous() + pixel_values = self.pixel_transforms(pixel_values) + trans_image = pixel_values.numpy() + out_image = trans_image.astype(np.float32) + elif self.transform_backend == "al": + trans_image = self.pixel_transforms(image=image)["image"] + out_image = trans_image.astype(np.float32) + out_image = out_image.transpose((2, 0, 1)) # h w c -> c h w + + if self.expand_dim_t: + # c h w -> c t h w + out_image = np.expand_dims(out_image, axis=1) + + return out_image + + +def check_sanity(x, save_fp="./tmp.png"): + # reverse normalization and visulaize the transformed video + if len(x.shape) == 4: + print("only save the first image") + x = x[0] + x = np.transpose(x, (1, 2, 0)) + + x = (x + 1.0) / 2.0 # -1,1 -> 0,1 + x = (x * 255).astype(np.uint8) + + if isinstance(x, ms.Tensor): + x = x.asnumpy() + Image.fromarray(x).save(save_fp) + + +if __name__ == "__main__": + ds_config = dict( + csv_path="/home/mindocr/yx/datasets/chinese_art_blip/train/metadata.csv", + data_folder="/home/mindocr/yx/datasets/chinese_art_blip/train", + ) + # test source dataset + ds = ImageDataset(**ds_config) + sample = ds.__getitem__(0) + print(sample.shape) diff --git a/examples/magvit/videogvt/data/loader.py b/examples/magvit/videogvt/data/loader.py new file mode 100644 index 0000000000..d10aeed488 --- /dev/null +++ b/examples/magvit/videogvt/data/loader.py @@ -0,0 +1,95 @@ +import mindspore as ms + +from .image_dataset import ImageDataset +from .video_dataset import VideoDataset + + +def create_dataloader( + ds_config, + batch_size, + ds_name="image", + num_parallel_workers=12, + max_rowsize=32, + shuffle=True, + device_num=1, + rank_id=0, + drop_remainder=True, +): + """ + Args: + ds_config, dataset config, args for ImageDataset or VideoDataset + ds_name: dataset name, image or video + """ + if ds_name == "image": + dataset = ImageDataset(**ds_config) + elif ds_name == "video": + dataset = VideoDataset(**ds_config) + print("Total number of samples: ", len(dataset)) + + # Larger value leads to more memory consumption. 
Default: 16 + # prefetch_size = config.get("prefetch_size", 16) + # ms.dataset.config.set_prefetch_size(prefetch_size) + + dataloader = ms.dataset.GeneratorDataset( + source=dataset, + column_names=[ds_name], + num_shards=device_num, + shard_id=rank_id, + python_multiprocessing=True, + shuffle=shuffle, + num_parallel_workers=num_parallel_workers, + max_rowsize=max_rowsize, + ) + + dl = dataloader.batch( + batch_size, + drop_remainder=drop_remainder, + ) + + return dl + + +if __name__ == "__main__": + import math + import time + + from tqdm import tqdm + + ds_config = dict( + csv_path="/home/mindocr/yx/datasets/chinese_art_blip/train/metadata.csv", + data_folder="/home/mindocr/yx/datasets/chinese_art_blip/train", + ) + + # test loader + dl = create_dataloader(ds_config, 4) + + num_batches = dl.get_dataset_size() + # ms.set_context(mode=0) + print(num_batches) + + steps = 50 + iterator = dl.create_dict_iterator(100) # create 100 repeats + tot = 0 + + progress_bar = tqdm(range(steps)) + progress_bar.set_description("Steps") + + start = time.time() + for epoch in range(math.ceil(steps / num_batches)): + for i, batch in enumerate(iterator): + print("epoch", epoch, "step", i) + dur = time.time() - start + tot += dur + + if epoch * num_batches + i < 2: + for k in batch: + print(k, batch[k].shape, batch[k].dtype) # , batch[k].min(), batch[k].max()) + print(f"time cost: {dur * 1000} ms") + + progress_bar.update(1) + if i + 1 > steps: # in case the data size is too large + break + start = time.time() + + mean = tot / steps + print("Avg batch loading time: ", mean) diff --git a/examples/magvit/videogvt/data/video_dataset.py b/examples/magvit/videogvt/data/video_dataset.py new file mode 100644 index 0000000000..92394644fd --- /dev/null +++ b/examples/magvit/videogvt/data/video_dataset.py @@ -0,0 +1,238 @@ +import copy +import csv +import glob +import logging +import os +import random + +import albumentations +import cv2 +import imageio +import numpy as np +from decord import VideoReader +from PIL import Image, ImageSequence + +logger = logging.getLogger() + + +def read_gif(gif_path, mode="RGB"): + with Image.open(gif_path) as fp: + frames = np.array([np.array(frame.convert(mode)) for frame in ImageSequence.Iterator(fp)]) + return frames + + +def create_video_transforms( + size=384, + crop_size=256, + interpolation="bicubic", + backend="al", + random_crop=False, + flip=False, + num_frames=None, +): + if backend == "al": + # expect rgb image in range 0-255, shape (h w c) + from albumentations import CenterCrop, HorizontalFlip, RandomCrop, SmallestMaxSize + + # NOTE: to ensure augment all frames in a video in the same way. 
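+        # albumentations samples one set of transform parameters per call; the
+        # `additional_targets` dict built below registers extra keys (image0, image1, ...)
+        # as "image" targets so every frame of a clip receives the identical crop/flip.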
+ assert num_frames is not None, "num_frames must be parsed" + targets = {"image{}".format(i): "image" for i in range(num_frames)} + mapping = {"bilinear": cv2.INTER_LINEAR, "bicubic": cv2.INTER_CUBIC} + transforms = [ + SmallestMaxSize(max_size=size, interpolation=mapping[interpolation]), + (CenterCrop(crop_size, crop_size) if not random_crop else RandomCrop(crop_size, crop_size)), + ] + if flip: + transforms += [HorizontalFlip(p=0.5)] + + pixel_transforms = albumentations.Compose( + transforms, + additional_targets=targets, + ) + else: + raise NotImplementedError + + return pixel_transforms + + +def get_video_path_list(folder): + # TODO: find recursively + fmts = ["avi", "mp4", "gif"] + out = [] + for fmt in fmts: + out += glob.glob(os.path.join(folder, f"*.{fmt}")) + return sorted(out) + + +class VideoDataset: + def __init__( + self, + csv_path=None, + data_folder=None, + size=384, + crop_size=256, + random_crop=False, + flip=False, + sample_stride=4, + sample_n_frames=16, + return_image=False, + transform_backend="al", + video_column="video", + disable_flip=False, + ): + logger.info(f"loading annotations from {csv_path} ...") + + if csv_path is not None: + with open(csv_path, "r") as csvfile: + self.dataset = list(csv.DictReader(csvfile)) + self.read_from_csv = True + else: + self.dataset = get_video_path_list(data_folder) + self.read_from_csv = False + + self.length = len(self.dataset) + logger.info(f"Num data samples: {self.length}") + logger.info(f"sample_n_frames: {sample_n_frames}") + + self.data_folder = data_folder + self.sample_stride = sample_stride + self.sample_n_frames = sample_n_frames + self.return_image = return_image + + self.pixel_transforms = create_video_transforms( + size=size, + crop_size=crop_size, + random_crop=random_crop, + flip=flip, + num_frames=sample_n_frames, + ) + self.transform_backend = transform_backend + self.video_column = video_column + + # prepare replacement data + max_attempts = 100 + self.prev_ok_sample = self.get_replace_data(max_attempts) + self.require_update_prev = False + + def get_replace_data(self, max_attempts=100): + replace_data = None + attempts = min(max_attempts, self.length) + for idx in range(attempts): + try: + pixel_values = self.get_batch(idx) + replace_data = copy.deepcopy(pixel_values) + break + except Exception as e: + print("\tError msg: {}".format(e)) + + assert replace_data is not None, f"Fail to preload sample in {attempts} attempts." 
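+        # the preloaded sample is stored as `self.prev_ok_sample` and is returned by
+        # __getitem__ whenever decoding a video fails, so a corrupted file does not
+        # interrupt iteration.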
+ + return replace_data + + def get_batch(self, idx): + # get video raw pixels (batch of frame) and its caption + if self.read_from_csv: + video_dict = self.dataset[idx] + video_fn = video_dict[list(video_dict.keys())[0]] + video_path = os.path.join(self.data_folder, video_fn) + else: + video_path = self.dataset[idx] + + if video_path.endswith(".gif"): + video_reader = read_gif(video_path, mode="RGB") + else: + video_reader = VideoReader(video_path) + + video_length = len(video_reader) + + if not self.return_image: + clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1) + start_idx = random.randint(0, video_length - clip_length) + batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int) + else: + batch_index = [random.randint(0, video_length - 1)] + + if video_path.endswith(".gif"): + pixel_values = video_reader[batch_index] # shape: (f, h, w, c) + else: + pixel_values = video_reader.get_batch(batch_index).asnumpy() # shape: (f, h, w, c) + + del video_reader + + return pixel_values + + def __len__(self): + return self.length + + def __getitem__(self, idx): + """ + Returns: + video: preprocessed video frames in shape (f, c, h, w), normalized to [-1, 1] + """ + try: + pixel_values = self.get_batch(idx) + if (self.prev_ok_sample is None) or (self.require_update_prev): + self.prev_ok_sample = copy.deepcopy(pixel_values) + self.require_update_prev = False + except Exception as e: + logger.warning(f"Fail to get sample of idx {idx}. The corrupted video will be replaced.") + print("\tError msg: {}".format(e), flush=True) + assert self.prev_ok_sample is not None + pixel_values = self.prev_ok_sample # unless the first sample is already not ok + self.require_update_prev = True + + if idx >= self.length: + raise IndexError # needed for checking the end of dataset iteration + + num_frames = len(pixel_values) + # pixel value: (f, h, w, 3) -> transforms -> (f 3 h' w') + if self.transform_backend == "al": + # NOTE:it's to ensure augment all frames in a video in the same way. 
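+            # e.g. for a 3-frame clip the call below is equivalent to
+            #   self.pixel_transforms(image=frames[0], image0=frames[1], image1=frames[2])
+            # and all outputs share the same sampled crop/flip parameters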
+ # ref: https://albumentations.ai/docs/examples/example_multi_target/ + + inputs = {"image": pixel_values[0]} + for i in range(num_frames - 1): + inputs[f"image{i}"] = pixel_values[i + 1] + + output = self.pixel_transforms(**inputs) + + pixel_values = np.stack(list(output.values()), axis=0) + # (t h w c) -> (c t h w) + pixel_values = np.transpose(pixel_values, (3, 0, 1, 2)) + else: + raise NotImplementedError + + if self.return_image: + pixel_values = pixel_values[1] + + pixel_values = (pixel_values / 127.5 - 1.0).astype(np.float32) + + return pixel_values + + +# TODO: parse in config dict +def check_sanity(x, save_fp="./tmp.gif"): + # reverse normalization and visulaize the transformed video + # (c, t, h, w) -> (t, h, w, c) + if len(x.shape) == 3: + x = np.expand_dims(x, axis=0) + x = np.transpose(x, (1, 2, 3, 0)) + + x = (x + 1.0) / 2.0 # -1,1 -> 0,1 + x = (x * 255).astype(np.uint8) + + imageio.mimsave(save_fp, x, duration=1 / 8.0, loop=1) + + +if __name__ == "__main__": + ds_config = dict( + data_folder="../videocomposer/datasets/webvid5", + random_crop=True, + flip=True, + ) + # test source dataset + ds = VideoDataset(**ds_config) + sample = ds.__getitem__(0) + print(sample.shape) + + check_sanity(sample) diff --git a/examples/magvit/videogvt/eval/__init__.py b/examples/magvit/videogvt/eval/__init__.py new file mode 100644 index 0000000000..c4dd906cbd --- /dev/null +++ b/examples/magvit/videogvt/eval/__init__.py @@ -0,0 +1,2 @@ +from .cal_psnr import calculate_psnr +from .cal_ssim import calculate_ssim diff --git a/examples/magvit/videogvt/eval/cal_psnr.py b/examples/magvit/videogvt/eval/cal_psnr.py new file mode 100644 index 0000000000..354c2c2cfd --- /dev/null +++ b/examples/magvit/videogvt/eval/cal_psnr.py @@ -0,0 +1,80 @@ +import numpy as np +from skimage.metrics import peak_signal_noise_ratio as cal_psnr +from tqdm import tqdm + + +def trans(x): + return x + + +def calculate_psnr(videos1, videos2): + print("calculate_psnr...") + + # videos [batch_size, timestamps, channel, h, w] + + assert videos1.shape == videos2.shape + + videos1 = trans(videos1) + videos2 = trans(videos2) + + psnr_results = [] + + for video_num in tqdm(range(videos1.shape[0])): + # get a video + # video [timestamps, channel, h, w] + video1 = videos1[video_num] + video2 = videos2[video_num] + + psnr_results_of_a_video = [] + for clip_timestamp in range(len(video1)): + # get a img + # img [timestamps[x], channel, h, w] + # img [channel, h, w] numpy + + img1 = video1[clip_timestamp] + img2 = video2[clip_timestamp] + + # calculate psnr of a video + psnr_results_of_a_video.append(cal_psnr(img1, img2)) + + psnr_results.append(psnr_results_of_a_video) + + psnr_results = np.array(psnr_results) # [batch_size, num_frames] + psnr = {} + psnr_std = {} + + for clip_timestamp in range(len(video1)): + psnr[clip_timestamp] = np.mean(psnr_results[:, clip_timestamp]) + psnr_std[clip_timestamp] = np.std(psnr_results[:, clip_timestamp]) + + result = { + "value": psnr, + "value_std": psnr_std, + "video_setting": video1.shape, + "video_setting_name": "time, channel, heigth, width", + } + + return result + + +# test code / using example + + +def main(): + from mindspore import ops + + NUMBER_OF_VIDEOS = 8 + VIDEO_LENGTH = 50 + CHANNEL = 3 + SIZE = 64 + videos1 = ops.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) + videos2 = ops.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) + + import json + + result = calculate_psnr(videos1, videos2) + print(json.dumps(result, 
indent=4)) + + +if __name__ == "__main__": + main() diff --git a/examples/magvit/videogvt/eval/cal_ssim.py b/examples/magvit/videogvt/eval/cal_ssim.py new file mode 100644 index 0000000000..4ffcaac878 --- /dev/null +++ b/examples/magvit/videogvt/eval/cal_ssim.py @@ -0,0 +1,100 @@ +import numpy as np +from skimage.metrics import structural_similarity as ssim +from tqdm import tqdm + + +def calculate_ssim_function(img1, img2): + # [0,1] + # ssim is the only metric extremely sensitive to gray being compared to b/w + if not img1.shape == img2.shape: + raise ValueError("Input images must have the same dimensions.") + if img1.ndim == 2: + return ssim(img1, img2) + elif img1.ndim == 3: + if img1.shape[0] == 3: + ssims = [] + for i in range(3): + ssims.append(ssim(img1[i], img2[i])) + return np.array(ssims).mean() + elif img1.shape[0] == 1: + return ssim(np.squeeze(img1), np.squeeze(img2)) + else: + raise ValueError("Wrong input image dimensions.") + + +def trans(x): + return x + + +def calculate_ssim(videos1, videos2): + print("calculate_ssim...") + + # videos [batch_size, timestamps, channel, h, w] + + assert videos1.shape == videos2.shape + + videos1 = trans(videos1) + videos2 = trans(videos2) + + ssim_results = [] + + for video_num in tqdm(range(videos1.shape[0])): + # get a video + # video [timestamps, channel, h, w] + video1 = videos1[video_num] + video2 = videos2[video_num] + + ssim_results_of_a_video = [] + for clip_timestamp in range(len(video1)): + # get a img + # img [timestamps[x], channel, h, w] + # img [channel, h, w] numpy + + img1 = video1[clip_timestamp] + img2 = video2[clip_timestamp] + + # calculate ssim of a video + ssim_results_of_a_video.append(calculate_ssim_function(img1, img2)) + + ssim_results.append(ssim_results_of_a_video) + + ssim_results = np.array(ssim_results) + + ssim = {} + ssim_std = {} + + for clip_timestamp in range(len(video1)): + ssim[clip_timestamp] = np.mean(ssim_results[:, clip_timestamp]) + ssim_std[clip_timestamp] = np.std(ssim_results[:, clip_timestamp]) + + result = { + "value": ssim, + "value_std": ssim_std, + "video_setting": video1.shape, + "video_setting_name": "time, channel, heigth, width", + } + + return result + + +# test code / using example + + +def main(): + from mindspore import ops + + NUMBER_OF_VIDEOS = 8 + VIDEO_LENGTH = 50 + CHANNEL = 3 + SIZE = 64 + videos1 = ops.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) + videos2 = ops.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) + + import json + + result = calculate_ssim(videos1, videos2) + print(json.dumps(result, indent=4)) + + +if __name__ == "__main__": + main() diff --git a/examples/magvit/videogvt/models/quantization/__init__.py b/examples/magvit/videogvt/models/quantization/__init__.py new file mode 100644 index 0000000000..8fde5c5d39 --- /dev/null +++ b/examples/magvit/videogvt/models/quantization/__init__.py @@ -0,0 +1 @@ +from .lookup_free_quantization import LFQ, LFQ2d diff --git a/examples/magvit/videogvt/models/quantization/lookup_free_quantization.py b/examples/magvit/videogvt/models/quantization/lookup_free_quantization.py new file mode 100644 index 0000000000..a2a56f5728 --- /dev/null +++ b/examples/magvit/videogvt/models/quantization/lookup_free_quantization.py @@ -0,0 +1,331 @@ +""" +Lookup Free Quantization +Proposed in https://arxiv.org/abs/2310.05737 + +In the simplest setup, each dimension is quantized into {-1, 1}. +An entropy penalty is used to encourage utilization. 
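+Codes are the signs of the (projected) latents; token indices are obtained by packing the
+sign bits into an integer, so quantization itself requires no codebook lookup.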
+""" + +from math import ceil, log2 + +import mindspore as ms +from mindspore import nn, ops + + +def exists(v): + return v is not None + + +def default(*args): + for arg in args: + if exists(arg): + return arg() if callable(arg) else arg + return None + + +def log(t, eps=1e-5): + return t.clamp(min=eps).log() + + +def entropy(prob): + return (-prob * log(prob)).sum(axis=-1) + + +# cosine sim linear + + +class CosineSimLinear(nn.Cell): + def __init__( + self, + dim_in, + dim_out, + scale=1.0, + dtype=ms.float32, + ): + super().__init__() + self.scale = scale + self.weight = ms.Parameter(ops.randn((dim_in, dim_out), dtype=dtype)) + + def construct(self, x): + x = ops.L2Normalize(axis=-1, epsilon=1e-12)(x) + w = ops.L2Normalize(axis=0, epsilon=1e-12)(self.weight) + return (x @ w) * self.scale + + +# class + + +class LFQ(nn.Cell): + def __init__( + self, + config, + return_loss_breakdown=False, + is_training=False, + dtype=ms.float32, + ): + super(LFQ, self).__init__() + + dim = config.dim + codebook_size = config.codebook_size + entropy_loss_weight = config.entropy_loss_weight + commitment_loss_weight = config.commitment_loss_weight + diversity_gamma = config.diversity_gamma + num_codebooks = config.num_codebooks + keep_num_codebooks_dim = config.keep_num_codebooks_dim + codebook_scale = config.codebook_scale # for residual LFQ, codebook scaled down by 2x at each layer + frac_per_sample_entropy = ( + config.frac_per_sample_entropy + ) # make less than 1. to only use a random fraction of the probs for per sample entropy + inv_temperature = config.inv_temperature + soft_clamp_input_value = config.soft_clamp_input_value + cosine_sim_project_in = config.cosine_sim_project_in + cosine_sim_project_in_scale = config.cosine_sim_project_in_scale + + # some assert validations + + assert exists(dim) or exists(codebook_size), "either dim or codebook_size must be specified for LFQ" + assert ( + not exists(codebook_size) or log2(codebook_size).is_integer() + ), f"your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})" + + codebook_size = default(codebook_size, lambda: 2**dim) + codebook_dim = int(log2(codebook_size)) + + codebook_dims = codebook_dim * num_codebooks + dim = default(dim, codebook_dims) + + if cosine_sim_project_in: + cosine_sim_project_in_scale = default(cosine_sim_project_in_scale, codebook_scale) + project_in_klass = CosineSimLinear(dim, codebook_dims, scale=cosine_sim_project_in_scale) + else: + project_in_klass = nn.Dense(dim, codebook_dims, dtype=dtype) + + has_projections = dim != codebook_dims + self.project_in = project_in_klass if has_projections else nn.Identity() + self.project_out = nn.Dense(codebook_dims, dim, dtype=dtype) if has_projections else nn.Identity() + self.has_projections = has_projections + + self.dim = dim + self.codebook_dim = codebook_dim + self.num_codebooks = num_codebooks + self.return_loss_breakdown = return_loss_breakdown + + keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1) + assert not (num_codebooks > 1 and not keep_num_codebooks_dim) + self.keep_num_codebooks_dim = keep_num_codebooks_dim + + # straight through activation + + self.activation = nn.Identity() + + # entropy aux loss related weights + + assert 0 < frac_per_sample_entropy <= 1.0 + self.frac_per_sample_entropy = frac_per_sample_entropy + + self.diversity_gamma = diversity_gamma + self.entropy_loss_weight = entropy_loss_weight + + # codebook scale + + self.codebook_scale = codebook_scale + + # commitment loss + 
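+        # weight of the MSE between the encoder output and its stop-gradient quantized version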
+ self.commitment_loss_weight = commitment_loss_weight + + # whether to soft clamp the input value from -value to value + + self.soft_clamp_input_value = soft_clamp_input_value + assert not exists(soft_clamp_input_value) or soft_clamp_input_value >= codebook_scale + + # for no auxiliary loss, during inference + + self.mask = ops.pow(2, ops.arange(codebook_dim - 1, -1, -1)) + + # temperature + self.inv_temperature = inv_temperature + + # codes + + all_codes = ops.arange(codebook_size) + bits = ((all_codes[..., None] & self.mask) != 0).float().astype(dtype) + self.codebook = self.bits_to_codes(bits) + + # training + self.is_training = is_training + + def bits_to_codes(self, bits): + return bits * self.codebook_scale * 2 - self.codebook_scale + + def dtype(self): + return self.codebook.dtype + + def indices_to_codes(self, indices, project_out=True): + if not self.keep_num_codebooks_dim: + # indices = rearrange(indices, '... -> ... 1') + indices = indices.unsqueeze(-1) + + # indices to codes, which are bits of either -1 or 1 + + bits = ((indices[..., None].int() & self.mask) != 0).float() + + codes = self.bits_to_codes(bits) + + # codes = rearrange(codes, '... c d -> ... (c d)') + b, h, w, c, d = codes.shape + codes = codes.reshape(b, h, w, c * d) + + # whether to project codes out to original dimensions + # if the input feature dimensions were not log2(codebook size) + + if project_out: + codes = self.project_out(codes) + + # rearrange codes back to original shape + + # codes = rearrange(codes, 'b ... d -> b d ...') + codes = codes.permute(0, 4, 1, 2, 3) + + return codes + + def _forward(self, x): + x_shape = x.shape + # x, ps = pack_one(x, 'b * d') + b = x.shape[0] + d = x.shape[-1] + x = x.reshape(b, -1, d) + + assert x.shape[-1] == self.dim, f"expected dimension of {self.dim} but received {x.shape[-1]}" + + x = self.project_in(x) + + # maybe soft clamp + + if exists(self.soft_clamp_input_value): + clamp_value = self.soft_clamp_input_value + x = (x / clamp_value).tanh() * clamp_value + + # split out number of codebooks + + # x = rearrange(x, 'b n (c d) -> b n c d', c = self.num_codebooks) + b, n, _ = x.shape + x = x.reshape(b, n, self.num_codebooks, -1) + + # quantize by eq 3. + + original_input = x + + codebook_value = ops.ones_like(x) * self.codebook_scale + quantized = ops.where(x > 0, codebook_value, -codebook_value) + + # use straight-through gradients (optionally with custom activation fn) if training + + if self.is_training: + x = self.activation(x) + x = x + ops.stop_gradient(quantized - x) + else: + x = quantized + + # calculate indices + indices = ops.sum((x > 0).int() * self.mask.int(), dim=-1) + + # entropy aux loss + + if self.is_training: + # the same as euclidean distance up to a constant + # distance = -2 * einsum('... i d, j d -> ... 
i j', original_input, self.codebook) + distance = -2 * ops.matmul(original_input, self.codebook.t()) + + prob = ops.softmax(-distance * self.inv_temperature, axis=-1) + + b, n, c, d = prob.shape + prob = prob.reshape(b * n, c, d) + + # whether to only use a fraction of probs, for reducing memory + + if self.frac_per_sample_entropy < 1.0: + num_tokens = prob.shape[0] + num_sampled_tokens = int(num_tokens * self.frac_per_sample_entropy) + rand_mask = ops.randn(num_tokens).argsort(dim=-1) < num_sampled_tokens + per_sample_probs = prob[rand_mask] + else: + per_sample_probs = prob + + # calculate per sample entropy + + per_sample_entropy = entropy(per_sample_probs).mean() + + # distribution over all available tokens in the batch + + avg_prob = ops.mean(per_sample_probs, axis=0) + codebook_entropy = entropy(avg_prob).mean() + + # 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions + # 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch + + entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy + else: + entropy_aux_loss = ms.Tensor(0.0) + per_sample_entropy = ms.Tensor(0.0) + codebook_entropy = ms.Tensor(0.0) + + # commit loss + + if self.is_training: + commit_loss = ops.mse_loss(original_input, ops.stop_gradient(quantized), reduction="mean") + commit_loss = commit_loss.mean() + else: + commit_loss = ms.Tensor(0.0) + + # complete aux loss + + aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight + + # merge back codebook dim + + # x = rearrange(x, 'b n c d -> b n (c d)') + b, n, c, d = x.shape + x = x.reshape(b, n, c * d) + + # project out to feature dimension if needed + + x = self.project_out(x) + + # reconstitute image or video dimensions + + x = x.reshape(*x_shape) + + return x, indices, aux_loss + + def construct( + self, + x, + ): + """ + einstein notation + b - batch + n - sequence (or flattened spatial dimensions) + d - feature dimension, which is also log2(codebook size) + c - number of codebook dim + """ + + # standardize image or video into (batch, seq, dimension) + # x = rearrange(x, 'b d ... -> b ... 
d') + x = x.permute(0, 2, 3, 4, 1) + + x, indices, aux_loss = self._forward(x=x) + + x = x.permute(0, 4, 1, 2, 3) + + indices = indices.squeeze(-1) + + return (x, indices, aux_loss) + + +class LFQ2d(LFQ): + def construct(self, x): + x = x.permute(0, 2, 3, 1) + x, indices, aux_loss = self._forward(x=x) + x = x.permute(0, 3, 1, 2) + return (x, indices, aux_loss) diff --git a/examples/magvit/videogvt/models/vqvae/__init__.py b/examples/magvit/videogvt/models/vqvae/__init__.py new file mode 100644 index 0000000000..edf29fedfb --- /dev/null +++ b/examples/magvit/videogvt/models/vqvae/__init__.py @@ -0,0 +1,32 @@ +import logging + +import mindspore as ms + +from .discriminator import StyleGANDiscriminator +from .vqvae import VQVAE_2D, VQVAE_3D + +logger = logging.getLogger(__name__) + + +def build_model(model_name, model_config, is_training=True, pretrained=None, dtype=ms.float32): + if model_name == "vqvae-2d": + model = VQVAE_2D( + model_config, + is_training=is_training, + dtype=dtype, + ) + elif model_name == "vqvae-3d": + model = VQVAE_3D( + model_config, + is_training=is_training, + dtype=dtype, + ) + else: + raise NotImplementedError(f"{model_name} is not implemented.") + + if pretrained is not None: + param_dict = ms.load_checkpoint(pretrained) + ms.load_param_into_net(model, param_dict) + logger.info(f"Loading vqvae from {pretrained}.") + + return model diff --git a/examples/magvit/videogvt/models/vqvae/discriminator.py b/examples/magvit/videogvt/models/vqvae/discriminator.py new file mode 100644 index 0000000000..3369e9ab47 --- /dev/null +++ b/examples/magvit/videogvt/models/vqvae/discriminator.py @@ -0,0 +1,214 @@ +"""3D StyleGAN discriminator.""" + +import math + +import ml_collections +import numpy as np + +import mindspore as ms +from mindspore import nn, ops + +from .model_utils import GroupNormExtend + + +def get_pad_layer(pad_type): + if pad_type in ["refl", "reflect"]: + PadLayer = nn.ReflectionPad3d + elif pad_type in ["repl", "replicate"]: + PadLayer = nn.ReplicationPad3d + elif pad_type == "zero": + PadLayer = nn.ConstantPad3d + else: + print("Pad type [%s] not recognized" % pad_type) + return PadLayer + + +class BlurPool3d(nn.Cell): + def __init__( + self, + channels, + pad_type="reflect", + filt_size=4, + stride=2, + pad_off=0, + dtype=ms.float32, + ): + super(BlurPool3d, self).__init__() + self.filt_size = filt_size + self.pad_off = pad_off + self.pad_sizes = [ + int(1.0 * (filt_size - 1) / 2), + int(np.ceil(1.0 * (filt_size - 1) / 2)), + int(1.0 * (filt_size - 1) / 2), + int(np.ceil(1.0 * (filt_size - 1) / 2)), + int(1.0 * (filt_size - 1) / 2), + int(np.ceil(1.0 * (filt_size - 1) / 2)), + ] + self.pad_sizes = tuple([pad_size + pad_off for pad_size in self.pad_sizes]) + self.stride = stride + self.off = int((self.stride - 1) / 2.0) + self.channels = channels + self.dtype = dtype + + if self.filt_size == 1: + a = np.array( + [ + 1.0, + ] + ) + elif self.filt_size == 2: + a = np.array([1.0, 1.0]) + elif self.filt_size == 3: + a = np.array([1.0, 2.0, 1.0]) + elif self.filt_size == 4: + a = np.array([1.0, 3.0, 3.0, 1.0]) + elif self.filt_size == 5: + a = np.array([1.0, 4.0, 6.0, 4.0, 1.0]) + elif self.filt_size == 6: + a = np.array([1.0, 5.0, 10.0, 10.0, 5.0, 1.0]) + elif self.filt_size == 7: + a = np.array([1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0]) + + filt = ms.Tensor( + np.repeat(np.expand_dims(a[:, None] * a[None, :], 0), self.filt_size, axis=0), + self.dtype, + ) + filt = filt / ops.sum(filt) + filt = filt.unsqueeze(0).unsqueeze(0) + filt = filt.repeat(self.channels, 
0).repeat(self.channels, 1) + self.filt = ms.Parameter(filt, requires_grad=False) + self.pad = ( + get_pad_layer(pad_type)(self.pad_sizes) + if pad_type != "zero" + else get_pad_layer(pad_type)(self.pad_sizes, 0) + ) + + def construct(self, inp): + if self.filt_size == 1: + if self.pad_off == 0: + return inp[:, :, :: self.stride, :: self.stride, :: self.stride] + else: + return self.pad(inp)[:, :, :: self.stride, :: self.stride, :: self.stride] + else: + return ops.conv3d(self.pad(inp), self.filt, stride=self.stride, groups=1) + + +class ResBlockDown(nn.Cell): + """3D StyleGAN ResBlock for D.""" + + def __init__( + self, + in_channels, + out_channels=None, + dropout=0.1, + dtype=ms.float32, + ): + super(ResBlockDown, self).__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + + self.conv1 = nn.Conv3d(self.in_channels, self.out_channels, (3, 3, 3)).to_float(dtype) + self.norm1 = GroupNormExtend( + num_groups=32, + num_channels=self.out_channels, + eps=1e-5, + affine=True, + dtype=dtype, + ) + self.activation1 = nn.LeakyReLU() + self.conv2 = nn.Conv3d(self.out_channels, self.out_channels, (3, 3, 3)).to_float(dtype) + self.norm2 = GroupNormExtend( + num_groups=32, + num_channels=self.out_channels, + eps=1e-5, + affine=True, + dtype=dtype, + ) + self.activation2 = nn.LeakyReLU() + # self.dropout = nn.Dropout(p=dropout) + + self.conv_shortcut = nn.Conv3d(self.in_channels, self.out_channels, (1, 1, 1), has_bias=False).to_float(dtype) + + self.blurpool1 = BlurPool3d(self.out_channels, pad_type="zero") + self.blurpool2 = BlurPool3d(self.in_channels, pad_type="zero") + + def construct(self, x): + h = x + h = self.conv1(h) + h = self.norm1(h) + h = self.activation1(h) + + # h = ops.AvgPool3D(strides=(2, 2, 2))(h) + h = self.blurpool1(h) + + h = self.conv2(h) + h = self.norm2(h) + h = self.activation2(h) + + # x = ops.AvgPool3D(strides=(2, 2, 2))(x) + x = self.blurpool2(x) + + x = self.conv_shortcut(x) + + out = (x + h) / ops.sqrt(ms.Tensor(2, ms.float32)) + return out + + +class StyleGANDiscriminator(nn.Cell): + """StyleGAN Discriminator.""" + + def __init__( + self, + config: ml_collections.ConfigDict, + height: int, + width: int, + depth: int, + dtype: ms.dtype = ms.float32, + ): + super().__init__() + self.config = config + self.in_channles = 3 + self.filters = self.config.filters + self.channel_multipliers = self.config.channel_multipliers + + self.conv_in = nn.Conv3d(self.in_channles, self.filters, kernel_size=(3, 3, 3)).to_float(dtype) + # self.activation1 = nn.LeakyReLU() + self.resnet_stack = nn.SequentialCell() + + num_blocks = len(self.channel_multipliers) + sampling_rate = math.pow(2, num_blocks) + for i in range(num_blocks): + filters = self.filters * self.channel_multipliers[i] + + if i == 0: + dim_in = self.filters + else: + dim_in = self.filters * self.channel_multipliers[i - 1] + + self.resnet_stack.append(ResBlockDown(dim_in, filters, dtype=dtype)) + + dim_out = self.filters * self.channel_multipliers[-1] + self.norm2 = GroupNormExtend(num_groups=32, num_channels=dim_out, eps=1e-5, affine=True, dtype=dtype) + self.conv_out = nn.Conv3d(dim_out, dim_out, (3, 3, 3)).to_float(dtype) + # self.activation2 = nn.LeakyReLU() + + dim_dense = int( + dim_out * max(1, height // sampling_rate) * max(1, width // sampling_rate) * max(1, depth // sampling_rate) + ) + + self.linear1 = nn.Dense(dim_dense, 512, dtype=dtype) + self.linear2 = nn.Dense(512, 1, dtype=dtype) + + def construct(self, x): + # x = self.norm(x) + x = 
self.conv_in(x) + x = ops.elu(x) + x = self.resnet_stack(x) + x = self.conv_out(x) + x = self.norm2(x) + x = ops.elu(x) + x = ops.reshape(x, (x.shape[0], -1)) + x = self.linear1(x) + x = ops.elu(x) + x = self.linear2(x) + return x diff --git a/examples/magvit/videogvt/models/vqvae/enc_dec_2dcnn.py b/examples/magvit/videogvt/models/vqvae/enc_dec_2dcnn.py new file mode 100644 index 0000000000..acf7a75412 --- /dev/null +++ b/examples/magvit/videogvt/models/vqvae/enc_dec_2dcnn.py @@ -0,0 +1,256 @@ +import mindspore as ms +from mindspore import nn + +from .model_utils import GroupNormExtend, get_activation_fn + + +class ResBlock(nn.Cell): + def __init__( + self, + in_channels, # SCH: added + filters, + conv_fn, + activation_fn=nn.SiLU, + use_conv_shortcut=False, + num_groups=32, + dtype=ms.float32, + ): + super().__init__() + self.in_channels = in_channels + self.filters = filters + self.activate = activation_fn() + self.use_conv_shortcut = use_conv_shortcut + + # SCH: MAGVIT uses GroupNorm by default + self.norm1 = GroupNormExtend(num_groups, in_channels, dtype=dtype) + self.conv1 = conv_fn(in_channels, self.filters, kernel_size=(3, 3), has_bias=False, dtype=dtype) + self.norm2 = GroupNormExtend(num_groups, self.filters, dtype=dtype) + self.conv2 = conv_fn(self.filters, self.filters, kernel_size=(3, 3), has_bias=False, dtype=dtype) + if in_channels != filters: + if self.use_conv_shortcut: + self.conv3 = conv_fn( + in_channels, + self.filters, + kernel_size=(3, 3), + has_bias=False, + dtype=dtype, + ) + else: + self.conv3 = conv_fn( + in_channels, + self.filters, + kernel_size=(1, 1), + has_bias=False, + dtype=dtype, + ) + + def construct(self, x): + residual = x + x = self.norm1(x) + x = self.activate(x) + x = self.conv1(x) + x = self.norm2(x) + x = self.activate(x) + x = self.conv2(x) + if self.in_channels != self.filters: # SCH: ResBlock X->Y + residual = self.conv3(residual) + return x + residual + + +class Encoder(nn.Cell): + """Encoder Blocks.""" + + def __init__( + self, + config, + dtype=ms.float32, + ): + super().__init__() + + self.filters = config.filters # 128 + self.num_res_blocks = config.num_enc_res_blocks + self.num_blocks = len(config.channel_multipliers) + self.channel_multipliers = config.channel_multipliers # (1, 2, 2, 4) + self.spatial_downsample = config.spatial_downsample + self.num_groups = config.num_groups + self.embedding_dim = config.embedding_dim # num channels for latent vector + + self.activation_fn = get_activation_fn(config.activation_fn) + self.activate = self.activation_fn() + self.conv_fn = nn.Conv2d + self.block_args = dict( + conv_fn=self.conv_fn, + activation_fn=self.activation_fn, + use_conv_shortcut=False, + num_groups=self.num_groups, + dtype=dtype, + ) + + # first layer conv + self.conv_in = self.conv_fn( + config.channels, + self.filters, + kernel_size=(3, 3), + has_bias=False, + dtype=dtype, + ) + + # ResBlocks and conv downsample + self.block_res_blocks = nn.CellList([]) + self.conv_blocks = nn.CellList([]) + + filters = self.filters + prev_filters = filters # record for in_channels + for i in range(self.num_blocks): + filters = self.filters * self.channel_multipliers[i] + block_items = nn.CellList([]) + for _ in range(self.num_res_blocks): + block_items.append(ResBlock(prev_filters, filters, **self.block_args)) + prev_filters = filters # update in_channels + self.block_res_blocks.append(block_items) + + if i < self.num_blocks - 1: + if self.spatial_downsample[i]: + s_stride = 2 + self.conv_blocks.append( + self.conv_fn( + prev_filters, + filters, + 
kernel_size=(3, 3), + stride=(s_stride, s_stride), + ) + ) + prev_filters = filters # update in_channels + else: + # if no t downsample, don't add since this does nothing for pipeline models + self.conv_blocks.append(nn.Identity()) # Identity + prev_filters = filters # update in_channels + + # last layer res block + self.res_blocks = nn.CellList([]) + for _ in range(self.num_res_blocks): + self.res_blocks.append(ResBlock(prev_filters, filters, **self.block_args)) + prev_filters = filters # update in_channels + + # MAGVIT uses Group Normalization + self.norm1 = GroupNormExtend(self.num_groups, prev_filters, dtype=dtype) + + self.conv2 = self.conv_fn( + prev_filters, + self.embedding_dim, + kernel_size=(1, 1), + pad_mode="same", + dtype=dtype, + ) + + def construct(self, x): + x = self.conv_in(x) + + for i in range(self.num_blocks): + for j in range(self.num_res_blocks): + x = self.block_res_blocks[i][j](x) + if i < self.num_blocks - 1: + x = self.conv_blocks[i](x) + for i in range(self.num_res_blocks): + x = self.res_blocks[i](x) + + x = self.norm1(x) + x = self.activate(x) + x = self.conv2(x) + return x + + +class Decoder(nn.Cell): + """Decoder Blocks.""" + + def __init__( + self, + config, + dtype=ms.float32, + ): + super().__init__() + self.filters = config.filters + self.in_out_channels = config.channels + self.num_res_blocks = config.num_dec_res_blocks + self.num_blocks = len(config.channel_multipliers) + self.channel_multipliers = config.channel_multipliers + self.spatial_downsample = config.spatial_downsample + self.num_groups = config.num_groups + self.embedding_dim = config.embedding_dim + self.s_stride = 2 + + self.activation_fn = get_activation_fn(config.activation_fn) + self.activate = self.activation_fn() + self.conv_fn = nn.Conv2d + self.block_args = dict( + conv_fn=self.conv_fn, + activation_fn=self.activation_fn, + use_conv_shortcut=False, + num_groups=self.num_groups, + dtype=dtype, + ) + + filters = self.filters * self.channel_multipliers[-1] + prev_filters = filters + + # last conv + self.conv1 = self.conv_fn(self.embedding_dim, filters, kernel_size=(3, 3), has_bias=True, dtype=dtype) + + # last layer res block + self.res_blocks = nn.CellList([]) + for _ in range(self.num_res_blocks): + self.res_blocks.append(ResBlock(filters, filters, **self.block_args)) + + # ResBlocks and conv upsample + self.block_res_blocks = nn.CellList([]) + self.num_blocks = len(self.channel_multipliers) + self.conv_blocks = nn.CellList([]) + # reverse to keep track of the in_channels, but append also in a reverse direction + for i in reversed(range(self.num_blocks)): + filters = self.filters * self.channel_multipliers[i] + # resblock handling + block_items = nn.CellList([]) + for _ in range(self.num_res_blocks): + block_items.append(ResBlock(prev_filters, filters, **self.block_args)) + prev_filters = filters # SCH: update in_channels + self.block_res_blocks.insert(0, block_items) # SCH: append in front + + # conv blocks with upsampling + if i > 0: + if self.spatial_downsample[i - 1]: + # SCH: T-Causal Conv 3x3x3, f -> (t_stride * 2 * 2) * f, depth to space t_stride x 2 x 2 + self.conv_blocks.insert( + 0, + self.conv_fn( + prev_filters, + prev_filters * self.s_stride * self.s_stride, + kernel_size=(3, 3), + dtype=dtype, + ), + ) + else: + self.conv_blocks.insert( + 0, + nn.Identity(), + ) + + self.norm1 = GroupNormExtend(self.num_groups, prev_filters, dtype=dtype) + + self.conv_out = self.conv_fn(filters, self.in_out_channels, 3, dtype=dtype) + + def construct(self, x): + x = self.conv1(x) + for i 
in range(self.num_res_blocks): + x = self.res_blocks[i](x) + for i in reversed(range(self.num_blocks)): + for j in range(self.num_res_blocks): + x = self.block_res_blocks[i][j](x) + if i > 0: + x = self.conv_blocks[i - 1](x) + b, c, h, w = x.shape + x = x.reshape(b, -1, h * self.s_stride, w * self.s_stride) + + x = self.norm1(x) + x = self.activate(x) + x = self.conv_out(x) + return x diff --git a/examples/magvit/videogvt/models/vqvae/enc_dec_3dcnn.py b/examples/magvit/videogvt/models/vqvae/enc_dec_3dcnn.py new file mode 100644 index 0000000000..2529cd8316 --- /dev/null +++ b/examples/magvit/videogvt/models/vqvae/enc_dec_3dcnn.py @@ -0,0 +1,280 @@ +import mindspore as ms +from mindspore import nn + +from .model_utils import CausalConv3d, GroupNormExtend, get_activation_fn + + +class ResBlock(nn.Cell): + def __init__( + self, + in_channels, # SCH: added + filters, + conv_fn, + activation_fn=nn.SiLU, + use_conv_shortcut=False, + num_groups=32, + dtype=ms.float32, + ): + super().__init__() + self.in_channels = in_channels + self.filters = filters + self.activate = activation_fn() + self.use_conv_shortcut = use_conv_shortcut + + # SCH: MAGVIT uses GroupNorm by default + self.norm1 = GroupNormExtend(num_groups, in_channels) + self.conv1 = conv_fn( + in_channels, + self.filters, + kernel_size=(3, 3, 3), + has_bias=False, + dtype=dtype, + ) + self.norm2 = GroupNormExtend(num_groups, self.filters) + self.conv2 = conv_fn( + self.filters, + self.filters, + kernel_size=(3, 3, 3), + has_bias=False, + dtype=dtype, + ) + if in_channels != filters: + if self.use_conv_shortcut: + self.conv3 = conv_fn( + in_channels, + self.filters, + kernel_size=(3, 3, 3), + has_bias=False, + dtype=dtype, + ) + else: + self.conv3 = conv_fn( + in_channels, + self.filters, + kernel_size=(1, 1, 1), + has_bias=False, + dtype=dtype, + ) + + def construct(self, x): + residual = x + x = self.norm1(x) + x = self.activate(x) + x = self.conv1(x) + x = self.norm2(x) + x = self.activate(x) + x = self.conv2(x) + if self.in_channels != self.filters: # SCH: ResBlock X->Y + residual = self.conv3(residual) + return x + residual + + +class Encoder(nn.Cell): + """Encoder Blocks. 
(magvit version)""" + + def __init__( + self, + config, + dtype=ms.float32, + ): + super().__init__() + self.filters = config.filters + self.num_res_blocks = config.num_enc_res_blocks + self.num_blocks = len(config.channel_multipliers) + self.channel_multipliers = config.channel_multipliers + self.temporal_downsample = config.temporal_downsample + self.spatial_downsample = config.spatial_downsample + self.num_groups = config.num_groups + self.embedding_dim = config.embedding_dim + + self.activation_fn = get_activation_fn(config.activation_fn) + self.activate = self.activation_fn() + self.conv_fn = CausalConv3d + self.block_args = dict( + conv_fn=self.conv_fn, + activation_fn=self.activation_fn, + use_conv_shortcut=False, + num_groups=self.num_groups, + dtype=dtype, + ) + + # first layer conv + self.conv_in = self.conv_fn( + config.channels, + self.filters, + kernel_size=(3, 3, 3), + has_bias=False, + dtype=dtype, + ) + + # ResBlocks and conv downsample + self.block_res_blocks = nn.CellList([]) + self.conv_blocks = nn.CellList([]) + + filters = self.filters + prev_filters = filters # record for in_channels + for i in range(self.num_blocks): + filters = self.filters * self.channel_multipliers[i] + block_items = nn.CellList([]) + for _ in range(self.num_res_blocks): + block_items.append(ResBlock(prev_filters, filters, **self.block_args)) + prev_filters = filters # update in_channels + self.block_res_blocks.append(block_items) + + if i < self.num_blocks - 1: + if self.spatial_downsample[i]: + t_stride = 2 if self.temporal_downsample[i] else 1 + s_stride = 2 + self.conv_blocks.append( + self.conv_fn( + prev_filters, + filters, + kernel_size=(3, 3, 3), + strides=(t_stride, s_stride, s_stride), + dtype=dtype, + ) + ) + else: + # if no t downsample, don't add since this does nothing for pipeline models + self.conv_blocks.append(nn.Identity()) # Identity + + prev_filters = filters # update in_channels + + # last layer res block + self.res_blocks = nn.CellList([]) + for _ in range(self.num_res_blocks): + self.res_blocks.append(ResBlock(prev_filters, filters, **self.block_args)) + prev_filters = filters # update in_channels + + # MAGVIT uses Group Normalization + self.norm1 = GroupNormExtend(self.num_groups, prev_filters) + + self.conv2 = self.conv_fn( + prev_filters, + self.embedding_dim, + kernel_size=(1, 1, 1), + pad_mode="same", + dtype=dtype, + ) + + def construct(self, x): + x = self.conv_in(x) + + for i in range(self.num_blocks): + for j in range(self.num_res_blocks): + x = self.block_res_blocks[i][j](x) + if i < self.num_blocks - 1: + x = self.conv_blocks[i](x) + for i in range(self.num_res_blocks): + x = self.res_blocks[i](x) + + x = self.norm1(x) + x = self.activate(x) + x = self.conv2(x) + return x + + +class Decoder(nn.Cell): + """Decoder Blocks. 
(magvit version)""" + + def __init__( + self, + config, + dtype=ms.float32, + ): + super().__init__() + self.filters = config.filters + self.num_res_blocks = config.num_dec_res_blocks + self.num_blocks = len(config.channel_multipliers) + self.channel_multipliers = config.channel_multipliers + self.temporal_downsample = config.temporal_downsample + self.spatial_downsample = config.spatial_downsample + self.num_groups = config.num_groups + self.embedding_dim = config.embedding_dim + self.s_stride = 2 + + self.activation_fn = get_activation_fn(config.activation_fn) + self.activate = self.activation_fn() + self.conv_fn = CausalConv3d + self.block_args = dict( + conv_fn=self.conv_fn, + activation_fn=self.activation_fn, + use_conv_shortcut=False, + num_groups=self.num_groups, + dtype=dtype, + ) + + filters = self.filters * self.channel_multipliers[-1] + prev_filters = filters + + # last conv + self.conv1 = self.conv_fn( + self.embedding_dim, + filters, + kernel_size=(3, 3, 3), + has_bias=True, + dtype=dtype, + ) + + # last layer res block + self.res_blocks = nn.CellList([]) + for _ in range(self.num_res_blocks): + self.res_blocks.append(ResBlock(filters, filters, **self.block_args)) + + # ResBlocks and conv upsample + self.block_res_blocks = nn.CellList([]) + self.num_blocks = len(self.channel_multipliers) + self.conv_blocks = nn.CellList([]) + + # reverse to keep track of the in_channels, but append also in a reverse direction + for i in reversed(range(self.num_blocks)): + filters = self.filters * self.channel_multipliers[i] + # resblock handling + block_items = nn.CellList([]) + for _ in range(self.num_res_blocks): + block_items.append(ResBlock(prev_filters, filters, **self.block_args)) + prev_filters = filters # SCH: update in_channels + self.block_res_blocks.insert(0, block_items) # SCH: append in front + + # conv blocks with upsampling + if i > 0: + # SCH: T-Causal Conv 3x3x3, f -> (t_stride * 2 * 2) * f, depth to space t_stride x 2 x 2 + if self.spatial_downsample[i - 1]: + t_stride = 2 if self.temporal_downsample[i - 1] else 1 + # SCH: T-Causal Conv 3x3x3, f -> (t_stride * 2 * 2) * f, depth to space t_stride x 2 x 2 + self.conv_blocks.insert( + 0, + self.conv_fn( + prev_filters, + prev_filters * t_stride * self.s_stride * self.s_stride, + kernel_size=(3, 3, 3), + dtype=dtype, + ), + ) + else: + self.conv_blocks.insert( + 0, + nn.Identity(), + ) + + self.norm1 = GroupNormExtend(self.num_groups, prev_filters) + + self.conv_out = self.conv_fn(filters, config.channels, 3, dtype=dtype) + + def construct(self, x): + x = self.conv1(x) + for k in range(self.num_res_blocks): + x = self.res_blocks[k](x) + for i in reversed(range(self.num_blocks)): + for j in range(self.num_res_blocks): + x = self.block_res_blocks[i][j](x) + if i > 0: + t_stride = 2 if self.temporal_downsample[i - 1] else 1 + x = self.conv_blocks[i - 1](x) + b, c, t, h, w = x.shape + x = x.reshape(b, -1, t * t_stride, h * self.s_stride, w * self.s_stride) + + x = self.norm1(x) + x = self.activate(x) + x = self.conv_out(x) + return x diff --git a/examples/magvit/videogvt/models/vqvae/lpips.py b/examples/magvit/videogvt/models/vqvae/lpips.py new file mode 100644 index 0000000000..560a75146c --- /dev/null +++ b/examples/magvit/videogvt/models/vqvae/lpips.py @@ -0,0 +1,146 @@ +import logging +import os + +import mindcv + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops + +_logger = logging.getLogger(__name__) + + +class LPIPS(nn.Cell): + # Learned perceptual metric + def __init__(self, use_dropout=True): + 
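+        # squared differences of unit-normalized VGG16 feature maps are reweighted by the
+        # learned 1x1 conv heads (lin0..lin4) and spatially averaged to produce the metric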
super().__init__() + self.scaling_layer = ScalingLayer() + self.chns = [64, 128, 256, 512, 512] # vgg16 features + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + # load NetLin metric layers + + # create vision backbone and load pretrained weights + self.net = vgg16(pretrained=True, requires_grad=False) + + self.set_train(False) + for param in self.trainable_params(): + param.requires_grad = False + + def load_from_pretrained(self, ckpt_path): + # TODO: just load ms ckpt + if not os.path.exists(ckpt_path): + raise ValueError( + f"{ckpt_path} not exists. Please download it from https://download-mindspore.osinfra.cn/toolkits/mindone/autoencoders/lpips_vgg-426bf45c.ckpt" + ) + + state_dict = ms.load_checkpoint(ckpt_path) + m, u = ms.load_param_into_net(self, state_dict) + if len(m) > 0: + print("missing keys:") + print(m) + if len(u) > 0: + print("unexpected keys:") + print(u) + + _logger.info("loaded pretrained LPIPS loss from {}".format(ckpt_path)) + + def construct(self, input, target): + in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + val = 0 # ms.Tensor(0, dtype=input.dtype) + lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + for kk in range(len(self.chns)): + diff = (normalize_tensor(outs0[kk]) - normalize_tensor(outs1[kk])) ** 2 + # res += spatial_average(lins[kk](diff), keepdim=True) + # lin_layer = lins[kk] + val += ops.mean(lins[kk](diff), axis=[2, 3], keep_dims=True) + return val + + +class ScalingLayer(nn.Cell): + def __init__(self): + super(ScalingLayer, self).__init__() + self.shift = ms.Tensor([-0.030, -0.088, -0.188])[None, :, None, None] + self.scale = ms.Tensor([0.458, 0.448, 0.450])[None, :, None, None] + + def construct(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Cell): + """A single linear layer which does a 1x1 conv""" + + def __init__(self, chn_in, chn_out=1, use_dropout=False, dtype=ms.float32): + super(NetLinLayer, self).__init__() + # TODO: can parse dtype=dtype in ms2.3 + layers = ( + [ + nn.Dropout(p=0.5).to_float(dtype), + ] + if (use_dropout) + else [] + ) + layers += [ + nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, has_bias=False).to_float(dtype), + ] + self.model = nn.SequentialCell(layers) + + def construct(self, x): + return self.model(x) + + +class vgg16(nn.Cell): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + # FIXME: add bias in vgg. use the same model weights in PT. 
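+        # the slice boundaries below (4, 9, 16, 23, 30) correspond to the relu1_2 / relu2_2 /
+        # relu3_3 / relu4_3 / relu5_3 outputs of the standard VGG16 feature stack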
+ model = mindcv.create_model("vgg16", pretrained=pretrained) + model.set_train(False) + vgg_pretrained_features = model.features + self.slice1 = nn.SequentialCell() + self.slice2 = nn.SequentialCell() + self.slice3 = nn.SequentialCell() + self.slice4 = nn.SequentialCell() + self.slice5 = nn.SequentialCell() + self.N_slices = 5 + for x in range(4): + self.slice1.append(vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.append(vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.append(vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.append(vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.append(vgg_pretrained_features[x]) + if not requires_grad: + for param in self.trainable_params(): + param.requires_grad = False + for param in model.trainable_params(): + param.requires_grad = False + + def construct(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + out = (h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + return out + + +def normalize_tensor(x, eps=1e-10): + norm_factor = ops.sqrt((x**2).sum(1, keepdims=True)) + return x / (norm_factor + eps) + + +def spatial_average(x, keepdim=True): + return x.mean([2, 3], keep_dims=keepdim) diff --git a/examples/magvit/videogvt/models/vqvae/model_utils.py b/examples/magvit/videogvt/models/vqvae/model_utils.py new file mode 100644 index 0000000000..1df14d8b58 --- /dev/null +++ b/examples/magvit/videogvt/models/vqvae/model_utils.py @@ -0,0 +1,249 @@ +from typing import Tuple, Union + +import mindspore as ms +from mindspore import nn, ops + + +def cast_tuple(t, length=1): + return t if isinstance(t, tuple) else ((t,) * length) + + +def divisible_by(num, den): + return (num % den) == 0 + + +def is_odd(n): + return not divisible_by(n, 2) + + +def pad_at_dim(t, pad, dim=-1, value=0.0): + dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1) + zeros = [(0, 0)] * dims_from_right + pad_op = ops.Pad(tuple(zeros + [pad] + [(0, 0)] * 2)) + return pad_op(t) + + +def exists(v): + return v is not None + + +def get_activation_fn(activation): + if activation == "relu": + activation_fn = nn.ReLU + elif activation == "swish": + activation_fn = nn.SiLU + else: + raise NotImplementedError + return activation_fn + + +class GroupNormExtend(nn.GroupNorm): + # GroupNorm supporting tensors with more than 4 dim + def construct(self, x): + x_shape = x.shape + if x.ndim >= 5: + x = x.view(x_shape[0], x_shape[1], x_shape[2], -1) + y = super().construct(x) + return y.view(x_shape) + + +class CausalConv3d(nn.Cell): + """ + Temporal padding: Padding with the first frame, by repeating K_t-1 times. 
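+    Note: the current `construct` applies the temporal padding as left zero padding of
+    size dilation * (K_t - 1) + (1 - stride) via ops.Pad.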
+ Spatial padding: follow standard conv3d, determined by pad mode and padding + Ref: opensora plan + + Args: + kernel_size: order (T, H, W) + stride: order (T, H, W) + padding: int, controls the amount of spatial padding applied to the input on both sides + """ + + def __init__( + self, + chan_in, + chan_out, + kernel_size: Union[int, Tuple[int, int, int]], + strides=None, + pad_mode="valid", + dtype=ms.float32, + **kwargs, + ): + super().__init__() + + kernel_size = cast_tuple(kernel_size, 3) + time_kernel_size, height_kernel_size, width_kernel_size = kernel_size + + assert is_odd(height_kernel_size) and is_odd(width_kernel_size) + + dilation = kwargs.pop("dilation", 1) + stride = strides[0] if strides is not None else kwargs.pop("stride", 1) + + # pad temporal dimension by k-1, manually + time_pad = dilation * (time_kernel_size - 1) + (1 - stride) + height_pad = height_kernel_size // 2 + width_pad = width_kernel_size // 2 + + self.time_pad = time_pad + self.time_causal_padding = ( + (0, 0), + (0, 0), + (time_pad, 0), + (height_pad, height_pad), + (width_pad, width_pad), + ) + + stride = strides if strides is not None else (stride, 1, 1) + dilation = (dilation, 1, 1) + + self.conv = nn.Conv3d( + chan_in, + chan_out, + kernel_size, + stride=stride, + dilation=dilation, + pad_mode=pad_mode, + dtype=dtype, + **kwargs, + ).to_float(dtype) + + def construct(self, x): + # x: (bs, Cin, T, H, W ) + op_pad = ops.Pad(self.time_causal_padding) + x = op_pad(x) + x = self.conv(x) + return x + + +class TimeDownsample2x(nn.Cell): + def __init__( + self, + dim, + dim_out, + kernel_size=3, + stride=1, + dtype=ms.float32, + ): + super().__init__() + self.time_causal_padding = (kernel_size - 1, 0) + self.conv = nn.Conv1d(dim, dim_out, kernel_size, stride=stride, pad_mode="valid", dtype=dtype).to_float(dtype) + + def construct(self, x): + b, c, t, h, w = x.shape + x = x.permute(0, 3, 4, 1, 2) + x = x.reshape(-1, c, t) + + x = ops.pad(x, self.time_causal_padding) + x = self.conv(x) + + x = x.reshape(b, h, w, c, -1) + x = x.permute(0, 3, 4, 1, 2) + + return x + + +class SpatialDownsample2x(nn.Cell): + def __init__( + self, + chan_in, + chan_out, + kernel_size: Union[int, Tuple[int]] = (3, 3), + stride: Union[int, Tuple[int]] = (2, 2), + dtype=ms.float32, + ): + super().__init__() + kernel_size = cast_tuple(kernel_size, 2) + stride = cast_tuple(stride, 2) + self.chan_in = chan_in + self.chan_out = chan_out + self.kernel_size = kernel_size + + self.conv = nn.Conv2d( + self.chan_in, + self.chan_out, + self.kernel_size, + stride=stride, + dtype=dtype, + ).to_float(dtype) + + def construct(self, x): + # x shape: (b c t h w) + + b, c_in, t_in, h_in, w_in = x.shape + + x = ops.permute(x, (0, 2, 1, 3, 4)) + x = x.reshape(b * t_in, c_in, h_in, w_in) + + x = self.conv(x) + + _, c_out, h_out, w_out = x.shape + x = x.reshape(b, t_in, c_out, h_out, w_out) + x = ops.permute(x, (0, 2, 1, 3, 4)) + + return x + + +class TimeUpsample2x(nn.Cell): + def __init__( + self, + dim, + dim_out, + kernel_size=3, + dtype=ms.float32, + ): + super().__init__() + + self.conv = nn.Conv1d(dim, dim_out * 2, kernel_size, dtype=dtype).to_float(dtype) + self.activate = nn.SiLU() + + def construct(self, x): + b, c, t, h, w = x.shape + x = x.permute(0, 3, 4, 1, 2) + x = x.reshape(-1, c, t) + + x = self.conv(x) + + x = x.reshape(b, h, w, -1, t * 2) + x = ops.permute(x, (0, 3, 4, 1, 2)) + + x = self.activate(x) + + return x + + +class SpatialUpsample2x(nn.Cell): + def __init__( + self, + chan_in, + chan_out, + kernel_size: Union[int, Tuple[int]] = 
(3, 3), + dtype=ms.float32, + ): + super().__init__() + self.chan_in = chan_in + self.chan_out = chan_out + self.kernel_size = kernel_size + + self.conv = nn.Conv2d( + self.chan_in, + self.chan_out * 4, + self.kernel_size, + dtype=dtype, + ).to_float(dtype) + + self.activate = nn.SiLU() + + def construct(self, x): + b, c_in, t_in, h_in, w_in = x.shape + + x = ops.permute(x, (0, 2, 1, 3, 4)) + x = ops.reshape(x, (b * t_in, c_in, h_in, w_in)) + + x = self.conv(x) + + x = ops.reshape(x, (b, t_in, self.chan_out, h_in * 2, w_in * 2)) + x = ops.permute(x, (0, 2, 1, 3, 4)) + + x = self.activate(x) + + return x diff --git a/examples/magvit/videogvt/models/vqvae/net_with_loss.py b/examples/magvit/videogvt/models/vqvae/net_with_loss.py new file mode 100644 index 0000000000..1e2169ac88 --- /dev/null +++ b/examples/magvit/videogvt/models/vqvae/net_with_loss.py @@ -0,0 +1,213 @@ +import mindspore as ms +from mindspore import nn, ops + +from .lpips import LPIPS + + +def _rearrange_in(x): + b, c, t, h, w = x.shape + x = x.permute(0, 2, 1, 3, 4) + x = ops.reshape(x, (b * t, c, h, w)) + + return x + + +def lecam_reg(real_pred, fake_pred, ema_real_pred, ema_fake_pred): + """Lecam loss for data-efficient and stable GAN training. + + Described in https://arxiv.org/abs/2104.03310 + + Args: + real_pred: Prediction (scalar) for the real samples. + fake_pred: Prediction for the fake samples. + ema_real_pred: EMA prediction (scalar) for the real samples. + ema_fake_pred: EMA prediction for the fake samples. + + Returns: + Lecam regularization loss (scalar). + """ + lecam_loss = ops.mean(ops.pow(ops.relu(real_pred - ema_fake_pred), 2)) + lecam_loss += ops.mean(ops.pow(ops.relu(ema_real_pred - fake_pred), 2)) + return lecam_loss + + +class GeneratorWithLoss(nn.Cell): + def __init__( + self, + vqvae, + disc_start=50001, + disc_weight=0.1, + disc_factor=1.0, + perceptual_weight=0.1, + recons_weight=5.0, + lecam_weight=0.001, + discriminator=None, + is_video=True, + dtype=ms.float32, + **kwargs, + ): + super().__init__() + + # build perceptual models for loss compute + self.vqvae = vqvae + self.perceptual_loss = LPIPS() # freeze params inside + + self.l1 = nn.L1Loss(reduction="none") + + self.disc_start = disc_start + self.disc_weight = disc_weight + self.disc_factor = disc_factor + self.recons_weight = recons_weight + self.perceptual_weight = perceptual_weight + self.lecam_weight = lecam_weight + + self.discriminator = discriminator + if (self.discriminator is not None) and (self.disc_factor > 0.0): + self.has_disc = True + else: + self.has_disc = False + + self.dtype = dtype + self.is_video = is_video + + def loss_function( + self, + x, + recons, + cond=None, + ): + if self.is_video: + x_reshape = _rearrange_in(x) + recons_reshape = _rearrange_in(recons) + else: + x_reshape = x + recons_reshape = recons + + # 2.1 entropy loss and commitment loss + + # 2.2 reconstruction loss in pixels + rec_loss = self.l1(recons_reshape, x_reshape) * self.recons_weight + + # 2.3 perceptual loss + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss(recons_reshape, x_reshape) + rec_loss = rec_loss + self.perceptual_weight * p_loss + + loss = rec_loss.mean() + + # 2.4 discriminator loss if enabled + if self.has_disc: + # calc gan loss + if cond is None: + logits_fake = self.discriminator(recons) + else: + logits_fake = self.discriminator(ops.concat((recons, cond), dim=1)) + g_loss = -ops.mean(logits_fake) + + # LeCAM regularization + logits_real = self.discriminator(x) + lecam_loss = lecam_reg( + logits_real, + logits_fake, + 
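+                    # NOTE: EMA discriminator predictions are not tracked yet; constant zeros are passed below as placeholders for ema_real_pred / ema_fake_pred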
ms.Tensor(0.0, self.dtype), + ms.Tensor(0.0, self.dtype), + ) + g_loss += lecam_loss * self.lecam_weight + + d_weight = self.disc_weight + loss += d_weight * self.disc_factor * g_loss + + return loss + + # in graph mode, construct code will run in graph. TODO: in pynative mode, need to add ms.jit decorator + def construct( + self, + x: ms.Tensor, + cond=None, + ): + """ + x: input images or videos, images: (b c h w), videos: (b c t h w) + global_step: global training step + """ + + # 1. AE forward, get aux_loss (entropy + commitment loss) and recons + _, _, recons, aux_loss = self.vqvae(x) + + # For videos, treat them as independent frame images + # TODO: regularize on temporal consistency + # if x.ndim >= 5: + # x: b c t h w -> (b*t c h w), shape for image perceptual loss + + # 2. compuate loss + loss = self.loss_function(x, recons, cond) + loss += aux_loss + + return loss + + +class DiscriminatorWithLoss(nn.Cell): + """ + Training logic: + For training step i, input data x: + 1. AE generator takes input x, feedforward to get posterior/latent and reconstructed data, and compute ae loss + 2. AE optimizer updates AE trainable params + 3. D takes the same input x, feed x to AE again **again** to get + the new posterior and reconstructions (since AE params has updated), feed x and recons to D, and compute D loss + 4. D optimizer updates D trainable params + --> Go to next training step + Ref: sd-vae training + """ + + def __init__( + self, + vqvae, + discriminator, + disc_start=50001, + disc_factor=1.0, + disc_loss="hinge", + ): + super().__init__() + self.vqvae = vqvae + self.discriminator = discriminator + self.disc_start = disc_start + self.disc_factor = disc_factor + + assert disc_loss in ["hinge", "vanilla"] + if disc_loss == "hinge": + self.disc_loss = self.hinge_loss + else: + self.softplus = ops.Softplus() + self.disc_loss = self.vanilla_d_loss + + def hinge_loss(self, logits_real, logits_fake): + loss_real = ops.mean(ops.relu(1.0 - logits_real)) + loss_fake = ops.mean(ops.relu(1.0 + logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + def vanilla_d_loss(self, logits_real, logits_fake): + d_loss = 0.5 * (ops.mean(self.softplus(-logits_real)) + ops.mean(self.softplus(logits_fake))) + return d_loss + + def construct(self, x: ms.Tensor): + """ + Second pass + Args: + x: input image/video, (bs c h w) + weights: sample weights + """ + + # 1. AE forward, get posterior (mean, logvar) and recons + _, _, recons, _ = self.vqvae(x) + + # 2. 
Disc forward to get class prediction on real input and reconstrucions + + logits_real = self.discriminator(x) + logits_fake = self.discriminator(recons) + + # logits_real = self.discriminator(ops.concat((x, cond), dim=1)) + # logits_fake = self.discriminator(ops.concat((recons, cond), dim=1)) + + d_loss = self.disc_factor * self.disc_loss(logits_real, logits_fake) + + return d_loss diff --git a/examples/magvit/videogvt/models/vqvae/vqvae.py b/examples/magvit/videogvt/models/vqvae/vqvae.py new file mode 100644 index 0000000000..51763d3fec --- /dev/null +++ b/examples/magvit/videogvt/models/vqvae/vqvae.py @@ -0,0 +1,145 @@ +from videogvt.models.quantization import LFQ, LFQ2d +from videogvt.models.vqvae import enc_dec_2dcnn, enc_dec_3dcnn + +import mindspore as ms +from mindspore import nn + +from .model_utils import CausalConv3d, pad_at_dim + + +class VQVAE_3D(nn.Cell): + def __init__( + self, + config, + is_training=True, + dtype=ms.float32, + ): + super().__init__() + + self.config = config.vqvae + self.dtype = dtype + self.encoder = enc_dec_3dcnn.Encoder(config=self.config, dtype=self.dtype) + self.decoder = enc_dec_3dcnn.Decoder(config=self.config, dtype=self.dtype) + self.quant_conv = CausalConv3d(self.config.embedding_dim, self.config.embedding_dim, 1, dtype=dtype) + self.post_quant_conv = CausalConv3d(self.config.embedding_dim, self.config.embedding_dim, 1, dtype=dtype) + self.quantizer = LFQ(config=config.lfq, is_training=is_training, dtype=self.dtype) + + self.time_downsample_factor = 2 ** sum(self.config.temporal_downsample) + self.patch_size = (self.time_downsample_factor, 1, 1) + self.out_channels = self.config.channels + self.num_frames = self.config.num_frames + + self.dtype = dtype + + def encode(self, x): + time_padding = ( + 0 + if (x.shape[2] % self.time_downsample_factor == 0) + else self.time_downsample_factor - x.shape[2] % self.time_downsample_factor + ) + x = pad_at_dim(x, (time_padding, 0), dim=2) + encoded_feature = self.encoder(x) + z_e = self.quant_conv(encoded_feature).to(x.dtype) + return z_e + + def decode(self, z): + time_padding = ( + 0 + if (self.num_frames % self.time_downsample_factor == 0) + else self.time_downsample_factor - self.num_frames % self.time_downsample_factor + ) + z = self.post_quant_conv(z) + x = self.decoder(z) + x = x[:, :, time_padding:] + return x + + def _forward(self, x): + # encode + z_e = self.encode(x) + + # quantization + embed_dtype = z_e.dtype + z_e = z_e.astype(self.dtype) + z_q, indices, aux_loss = self.quantizer(z_e) + + # decode + z_q = z_q.astype(embed_dtype) + recon_video = self.decode(z_q) + return recon_video + + def construct(self, x): + # encode + z_e = self.encode(x) + + # quantization + embed_dtype = z_e.dtype + z_e = z_e.astype(self.dtype) + z_q, indices, aux_loss = self.quantizer(z_e) + + # decode + z_q = z_q.astype(embed_dtype) + recon_video = self.decode(z_q) + + return z_e, z_q, recon_video, aux_loss + + +class VQVAE_2D(nn.Cell): + def __init__( + self, + config, + is_training=True, + dtype=ms.float32, + ): + super().__init__() + + self.config = config.vqvae + + self.space_downsample_factor = 2 ** sum(self.config.spatial_downsample) + self.patch_size = (self.space_downsample_factor, 1, 1) + self.out_channels = self.config.channels + + # NOTE: following MAGVIT, conv in bias=False in encoder first conv + self.encoder = enc_dec_2dcnn.Encoder(self.config, dtype=dtype) + self.decoder = enc_dec_2dcnn.Decoder(self.config, dtype=dtype) + self.quant_conv = nn.Conv2d(self.config.embedding_dim, self.config.embedding_dim, 1, 
dtype=dtype) + self.post_quant_conv = nn.Conv2d(self.config.embedding_dim, self.config.embedding_dim, 1, dtype=dtype) + + self.quantizer = LFQ2d( + config=config.lfq, + is_training=is_training, + dtype=dtype, + ) + + def encode(self, x): + encoded_feature = self.encoder(x) + z_e = self.quant_conv(encoded_feature).to(x.dtype) + return z_e + + def decode(self, z): + z = self.post_quant_conv(z) + x = self.decoder(z) + return x + + def _forward(self, x): + # encode + z_e = self.encode(x) + + # quantization + z_q, _, _ = self.quantizer(z_e) + + # decode + recon_video = self.decode(z_q) + + return recon_video + + def construct(self, x): + # encode + z_e = self.encode(x) + + # quantization + z_q, indices, aux_loss = self.quantizer(z_e) + + # decode + recon_video = self.decode(z_q) + + return z_e, z_q, recon_video, aux_loss
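
For quick reference, a minimal round-trip sketch of the 3D tokenizer defined in `videogvt/models/vqvae/vqvae.py`. It assumes that `get_config()` from `videogvt/config/vqgan3d_ucf101_config.py` supplies the `vqvae` and `lfq` sub-configs that `VQVAE_3D` reads, and the checkpoint path is a placeholder:

```python
# Round-trip sketch for the 3D tokenizer (assumes get_config() provides the
# `vqvae` and `lfq` sub-configs read by VQVAE_3D; the checkpoint path is a placeholder).
import numpy as np
import mindspore as ms

from videogvt.config.vqgan3d_ucf101_config import get_config
from videogvt.models.vqvae.vqvae import VQVAE_3D

config = get_config()
model = VQVAE_3D(config, is_training=False, dtype=ms.float32)
# ms.load_checkpoint("vqvae_3d.ckpt", model)  # placeholder path to a trained checkpoint
model.set_train(False)

# 17 frames of 128x128 RGB video in [-1, 1], NCTHW layout as expected by construct()
video = ms.Tensor(np.random.uniform(-1.0, 1.0, size=(1, 3, 17, 128, 128)), ms.float32)

z_e, z_q, recon, aux_loss = model(video)
# With the UCF-101 config (17x128x128 video -> 5x16x16 tokens), recon matches the
# input shape after the left time padding added in encode() is trimmed in decode().
print(z_q.shape, recon.shape)
```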