diff --git a/examples/opensora_hpcai/README.md b/examples/opensora_hpcai/README.md index 6b4bfd1bef..6c33cd4716 100644 --- a/examples/opensora_hpcai/README.md +++ b/examples/opensora_hpcai/README.md @@ -148,6 +148,7 @@ Your contributions are welcome. * [Data Processing](#data-processing) * [Training](#training) * [Evaluation](#evaluation) +* [VAE Training & Evaluation](#vae-training--evaluation) * [Contribution](#contribution) * [Acknowledgement](#acknowledgement) @@ -284,6 +285,7 @@ parameters is 724M. More information about training can be found in HPC-AI Tech' + ## Inference ### Open-Sora 1.2 and 1.1 Command Line Inference @@ -759,7 +761,80 @@ Here are some generation results after fine-tuning STDiT on a subset of WebVid d #### Quality Evaluation For quality evaluation, please refer to the original HPC-AI Tech [evaluation doc](https://github.com/hpcaitech/Open-Sora/blob/main/eval/README.md) for video generation quality evaluation. - + +## VAE Training & Evaluation + +A 3D-VAE pipeline consisting of a spatial VAE followed by a temporal VAE is trained in OpenSora v1.2. For more details, refer to the [VAE Documentation](https://github.com/hpcaitech/Open-Sora/blob/main/docs/vae.md). + +### Prepare Pretrained Weights + +- Download the pretrained VAE-2D checkpoint from [PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers](https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers/tree/main/vae) if you aim to train VAE-3D from a spatial-VAE initialization. + + Convert it to a MindSpore checkpoint: + ``` + python tools/convert_vae1.2.py --src /path/to/pixart_sigma_sdxlvae_T5_diffusers/vae/diffusion_pytorch_model.safetensors --target models/sdxl_vae.ckpt --from_vae2d + ``` + +- Download the pretrained VAE-3D checkpoint from [hpcai-tech/OpenSora-VAE-v1.2](https://huggingface.co/hpcai-tech/OpenSora-VAE-v1.2/tree/main) if you aim to train VAE-3D from the VAE-3D model pre-trained with 3 stages. + + Convert it to a MindSpore checkpoint: + ``` + python tools/convert_vae1.2.py --src /path/to/OpenSora-VAE-v1.2/models.safetensors --target models/OpenSora-VAE-v1.2/sdxl_vae.ckpt + ``` + +- Download the LPIPS MindSpore checkpoint from [here](https://download-mindspore.osinfra.cn/toolkits/mindone/autoencoders/lpips_vgg-426bf45c.ckpt) and put it under `models/`. + + +### Data Preprocess +Before VAE-3D training, prepare a csv annotation file for the training videos. The csv file lists the path of each video relative to the root `video_folder`. An example is: +``` +video +dance/vid001.mp4 +dance/vid002.mp4 +... +``` + +Taking UCF-101 as an example, please download the [UCF-101](https://www.crcv.ucf.edu/data/UCF101.php) dataset and extract it to the `datasets/UCF-101` folder. You can generate the csv annotation by running `python tools/annotate_vae_ucf101.py`. This results in two csv files, `datasets/ucf101_train.csv` and `datasets/ucf101_test.csv`, for training and testing respectively.
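+
+If your data is not UCF-101, a minimal sketch for building such an annotation csv from a custom video folder is given below. It is not part of this repo; the folder name, output csv path, and file extensions are placeholders to adapt to your data. The single `video` column matches the default `--video_column` of the training script.
+
+```python
+# Sketch: build a csv annotation (single "video" column) for a custom video folder.
+# Paths are stored relative to `video_folder`, matching the csv format shown above.
+import csv
+import glob
+import os
+
+video_folder = "datasets/my_videos"  # root folder later passed as --video_folder
+paths = []
+for ext in ("mp4", "avi", "gif"):
+    paths += glob.glob(os.path.join(video_folder, "**", f"*.{ext}"), recursive=True)
+
+with open("datasets/my_videos_train.csv", "w", newline="") as f:
+    writer = csv.writer(f)
+    writer.writerow(["video"])  # column name expected by the training script (--video_column)
+    for p in sorted(paths):
+        writer.writerow([os.path.relpath(p, video_folder)])
+```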
+ + +### Training +```bash +# stage 1 training, 8 NPUs +msrun --worker_num=8 --local_worker_num=8 \ +python scripts/train_vae.py --config configs/vae/train/stage1.yaml --use_parallel=True --csv_path datasets/ucf101_train.csv --video_folder datasets/UCF-101 + +# stage 2 training, 8 NPUs +msrun --worker_num=8 --local_worker_num=8 \ +python scripts/train_vae.py --config configs/vae/train/stage2.yaml --use_parallel=True --csv_path datasets/ucf101_train.csv --video_folder datasets/UCF-101 + +# stage 3 training, 8 NPUs +msrun --worker_num=8 --local_worker_num=8 \ +python scripts/train_vae.py --config configs/vae/train/stage3.yaml --use_parallel=True --csv_path datasets/ucf101_train.csv --video_folder datasets/UCF-101 +``` + +You can change the `csv_path` and `video_folder` to train on your own data. + +### Performance Evaluation +To evaluate the VAE performance, you need to run VAE inference first to reconstruct the test videos, and then calculate the quality scores (PSNR and SSIM) on the reconstructions: + +```bash +# video reconstruction and evaluation +python scripts/inference_vae.py --ckpt_path /path/to/your_vae_ckpt --image_size 256 --num_frames=17 --csv_path datasets/ucf101_test.csv --video_folder datasets/UCF-101 +``` + +You can change the `csv_path` and `video_folder` to evaluate on your own data. + +Here, we report the training performance and evaluation results on the UCF-101 dataset. + +| Model | Context | jit_level | Precision | BS | NPUs | Resolution (frames x H x W) | Train T. (s/step) | PSNR | SSIM | +|:------------|:-------------|:--------|:---------:|:--:|:----:|:----------------------:|:-----------------:|:-----------------:|:-----------------:| +| VAE-3D | D910\*-[MS2.3.1](https://www.mindspore.cn/install) | O0 | BF16 | 1 | 8 | stage1-17x256x256 | 0.21 | n.a. | n.a. | +| VAE-3D | D910\*-[MS2.3.1](https://www.mindspore.cn/install) | O2 | BF16 | 1 | 1 | stage2-17x256x256 | 0.41 | n.a. | n.a. | +| VAE-3D | D910\*-[CANN C18(0705)](https://repo.mindspore.cn/ascend/ascend910/20240705/)-[MS2.3](https://www.mindspore.cn/install) | O1 | BF16 | 1 | 8 | stage3-17x256x256 | 0.93 | 29.29 | 0.88 | +> Context: {G: GPU, D: Ascend}{chip type}-{MindSpore version}. + +Note that for stage 3 we train with the mixed video and image strategy, i.e. `--mixed_strategy=mixed_video_image`, instead of a random number of frames (`mixed_video_random`). Random-frame training will be supported in the future. + ## Training and Inference Using the FiT-Like Pipeline diff --git a/examples/opensora_hpcai/configs/vae/train/stage1.yaml b/examples/opensora_hpcai/configs/vae/train/stage1.yaml new file mode 100644 index 0000000000..e616eb7eec --- /dev/null +++ b/examples/opensora_hpcai/configs/vae/train/stage1.yaml @@ -0,0 +1,47 @@ +# model +model_type: "OpenSoraVAE_V1_2" +freeze_vae_2d: True +pretrained_model_path: "models/sdxl_vae.ckpt" + +# loss +perceptual_loss_weight: 0.1 +kl_loss_weight: 1.e-6 +use_real_rec_loss: False +use_z_rec_loss: True +use_image_identity_loss: True +mixed_strategy: "mixed_video_image" +mixed_image_ratio: 0.2 + +# data +dataset_name: "video" +csv_path: "../videocomposer/datasets/webvid5_copy.csv" +video_folder: "../videocomposer/datasets/webvid5" +frame_stride: 1 +num_frames: 17 +image_size: 256 + +micro_frame_size: null +micro_batch_size: null + +# training recipe +seed: 42 +use_discriminator: False +dtype: "bf16" +batch_size: 1 +clip_grad: True +max_grad_norm: 1.0 +start_learning_rate: 1.e-5 +scale_lr: False +use_recompute: False + +epochs: 2000 +ckpt_save_interval: 100 +init_loss_scale: 1.
+ +scheduler: "constant" +use_ema: False + +output_path: "outputs/causal_vae" + +# ms settting +jit_level: O1 diff --git a/examples/opensora_hpcai/configs/vae/train/stage2.yaml b/examples/opensora_hpcai/configs/vae/train/stage2.yaml new file mode 100644 index 0000000000..cde8c7f00a --- /dev/null +++ b/examples/opensora_hpcai/configs/vae/train/stage2.yaml @@ -0,0 +1,48 @@ +# model +model_type: "OpenSoraVAE_V1_2" +freeze_vae_2d: False +pretrained_model_path: "outputs/vae_stage1.ckpt" + +# loss +perceptual_loss_weight: 0.1 +kl_loss_weight: 1.e-6 +use_real_rec_loss: False +use_z_rec_loss: True +use_image_identity_loss: False +mixed_strategy: "mixed_video_image" +mixed_image_ratio: 0.2 + +# data +dataset_name: "video" +csv_path: "../videocomposer/datasets/webvid5_copy.csv" +video_folder: "../videocomposer/datasets/webvid5" +frame_stride: 1 +num_frames: 17 +image_size: 256 + +micro_frame_size: null +micro_batch_size: null +# flip: True + +# training recipe +seed: 42 +use_discriminator: False +dtype: "bf16" +batch_size: 1 +clip_grad: True +max_grad_norm: 1.0 +start_learning_rate: 1.e-5 +scale_lr: False +use_recompute: True + +epochs: 500 +ckpt_save_interval: 100 +init_loss_scale: 1. + +scheduler: "constant" +use_ema: False + +output_path: "outputs/vae_stage2" + +# ms settting +jit_level: O1 diff --git a/examples/opensora_hpcai/configs/vae/train/stage3.yaml b/examples/opensora_hpcai/configs/vae/train/stage3.yaml new file mode 100644 index 0000000000..012a6e6f9b --- /dev/null +++ b/examples/opensora_hpcai/configs/vae/train/stage3.yaml @@ -0,0 +1,49 @@ +# model +model_type: "OpenSoraVAE_V1_2" +freeze_vae_2d: False +pretrained_model_path: "outputs/vae_stage2.ckpt" + +# loss +perceptual_loss_weight: 0.1 +kl_loss_weight: 1.e-6 +use_real_rec_loss: True +use_z_rec_loss: False +use_image_identity_loss: False +mixed_strategy: "mixed_video_image" # TODO: use mixed_video_random after dynamic shape adaptation +mixed_image_ratio: 0.2 + +# data +dataset_name: "video" +csv_path: "../videocomposer/datasets/webvid5_copy.csv" +video_folder: "../videocomposer/datasets/webvid5" +frame_stride: 1 +num_frames: 17 # TODO: set 33 after dynamic shape adaptation and posterior concat fixed +image_size: 256 + +micro_frame_size: 17 +micro_batch_size: 4 +# flip: True + +# training recipe +seed: 42 +use_discriminator: False +dtype: "bf16" +batch_size: 1 +clip_grad: True +max_grad_norm: 1.0 +start_learning_rate: 1.e-5 +scale_lr: False +weight_decay: 0. +use_recompute: True + +epochs: 400 +ckpt_save_interval: 100 +init_loss_scale: 1. + +scheduler: "constant" +use_ema: False + +output_path: "outputs/vae_stage3" + +# ms settting +jit_level: O1 diff --git a/examples/opensora_hpcai/opensora/datasets/vae_dataset.py b/examples/opensora_hpcai/opensora/datasets/vae_dataset.py new file mode 100644 index 0000000000..e09310460a --- /dev/null +++ b/examples/opensora_hpcai/opensora/datasets/vae_dataset.py @@ -0,0 +1,358 @@ +import copy +import csv +import glob +import logging +import os +import random + +import albumentations +import cv2 +import imageio +import numpy as np +from decord import VideoReader + +import mindspore as ms + +logger = logging.getLogger() + + +def create_video_transforms( + size=384, crop_size=256, interpolation="bicubic", backend="al", random_crop=False, flip=False, num_frames=None +): + if backend == "al": + # expect rgb image in range 0-255, shape (h w c) + from albumentations import CenterCrop, HorizontalFlip, RandomCrop, SmallestMaxSize + + # NOTE: to ensure augment all frames in a video in the same way. 
+ assert num_frames is not None, "num_frames must be parsed" + targets = {"image{}".format(i): "image" for i in range(num_frames)} + mapping = {"bilinear": cv2.INTER_LINEAR, "bicubic": cv2.INTER_CUBIC} + transforms = [ + SmallestMaxSize(max_size=size, interpolation=mapping[interpolation]), + CenterCrop(crop_size, crop_size) if not random_crop else RandomCrop(crop_size, crop_size), + ] + if flip: + transforms += [HorizontalFlip(p=0.5)] + + pixel_transforms = albumentations.Compose( + transforms, + additional_targets=targets, + ) + else: + raise NotImplementedError + + return pixel_transforms + + +def get_video_path_list(folder): + # TODO: find recursively + fmts = ["avi", "mp4", "gif"] + out = [] + for fmt in fmts: + out += glob.glob(os.path.join(folder, f"*.{fmt}")) + return sorted(out) + + +class VideoDataset: + def __init__( + self, + csv_path=None, + data_folder=None, + size=384, + crop_size=256, + random_crop=False, + flip=False, + sample_stride=4, + sample_n_frames=16, + return_image=False, + transform_backend="al", + video_column="video", + ): + """ + size: image resize size + crop_size: crop size after resize operation + """ + logger.info(f"loading annotations from {csv_path} ...") + + if csv_path is not None: + with open(csv_path, "r") as csvfile: + self.dataset = list(csv.DictReader(csvfile)) + self.read_from_csv = True + else: + self.dataset = get_video_path_list(data_folder) + self.read_from_csv = False + + self.length = len(self.dataset) + logger.info(f"Num data samples: {self.length}") + logger.info(f"sample_n_frames: {sample_n_frames}") + + self.data_folder = data_folder + self.sample_stride = sample_stride + self.sample_n_frames = sample_n_frames + self.return_image = return_image + + self.pixel_transforms = create_video_transforms( + size=size, + crop_size=crop_size, + random_crop=random_crop, + flip=flip, + num_frames=sample_n_frames, + ) + self.transform_backend = transform_backend + self.video_column = video_column + + # prepare replacement data + max_attempts = 100 + self.prev_ok_sample = self.get_replace_data(max_attempts) + self.require_update_prev = False + + def get_replace_data(self, max_attempts=100): + replace_data = None + attempts = min(max_attempts, self.length) + for idx in range(attempts): + try: + pixel_values = self.get_batch(idx) + replace_data = copy.deepcopy(pixel_values) + break + except Exception as e: + print("\tError msg: {}".format(e)) + + assert replace_data is not None, f"Fail to preload sample in {attempts} attempts." 
+ + return replace_data + + def get_batch(self, idx): + # get video raw pixels (batch of frame) and its caption + if self.read_from_csv: + video_dict = self.dataset[idx] + video_fn = video_dict[list(video_dict.keys())[0]] + video_path = os.path.join(self.data_folder, video_fn) + else: + video_path = self.dataset[idx] + + video_reader = VideoReader(video_path) + + video_length = len(video_reader) + + if not self.return_image: + clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1) + start_idx = random.randint(0, video_length - clip_length) + batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int) + else: + batch_index = [random.randint(0, video_length - 1)] + + if video_path.endswith(".gif"): + pixel_values = video_reader[batch_index] # shape: (f, h, w, c) + else: + pixel_values = video_reader.get_batch(batch_index).asnumpy() # shape: (f, h, w, c) + + del video_reader + + return pixel_values + + def __len__(self): + return self.length + + def __getitem__(self, idx): + """ + Returns: + video: preprocessed video frames in shape (f, c, h, w), normalized to [-1, 1] + """ + try: + pixel_values = self.get_batch(idx) + if (self.prev_ok_sample is None) or (self.require_update_prev): + self.prev_ok_sample = copy.deepcopy(pixel_values) + self.require_update_prev = False + except Exception as e: + logger.warning(f"Fail to get sample of idx {idx}. The corrupted video will be replaced.") + print("\tError msg: {}".format(e), flush=True) + assert self.prev_ok_sample is not None + pixel_values = self.prev_ok_sample # unless the first sample is already not ok + self.require_update_prev = True + + if idx >= self.length: + raise IndexError # needed for checking the end of dataset iteration + + num_frames = len(pixel_values) + # pixel value: (f, h, w, 3) -> transforms -> (f 3 h' w') + if self.transform_backend == "al": + # NOTE:it's to ensure augment all frames in a video in the same way. + # ref: https://albumentations.ai/docs/examples/example_multi_target/ + + inputs = {"image": pixel_values[0]} + for i in range(num_frames - 1): + inputs[f"image{i}"] = pixel_values[i + 1] + + output = self.pixel_transforms(**inputs) + + pixel_values = np.stack(list(output.values()), axis=0) + # (t h w c) -> (c t h w) + pixel_values = np.transpose(pixel_values, (3, 0, 1, 2)) + else: + raise NotImplementedError + + if self.return_image: + pixel_values = pixel_values[1] + + pixel_values = (pixel_values / 127.5 - 1.0).astype(np.float32) + + return pixel_values + + +# TODO: parse in config dict +def check_sanity(x, save_fp="./tmp.gif"): + # reverse normalization and visulaize the transformed video + # (c, t, h, w) -> (t, h, w, c) + if len(x.shape) == 3: + x = np.expand_dims(x, axis=0) + x = np.transpose(x, (1, 2, 3, 0)) + + x = (x + 1.0) / 2.0 # -1,1 -> 0,1 + x = (x * 255).astype(np.uint8) + + imageio.mimsave(save_fp, x, duration=1 / 8.0, loop=1) + + +class BatchTransform: + def __init__(self, mixed_strategy, mixed_image_ratio=0.2): + self.mixed_strategy = mixed_strategy + self.mixed_image_ratio = mixed_image_ratio + + def __call__(self, x): + # x: (bs, c, t, h, w) + if self.mixed_strategy == "mixed_video_image": + if random.random() < self.mixed_image_ratio: + x = x[:, :, :1, :, :] + elif self.mixed_strategy == "mixed_video_random": + # TODO: somehow it's slow. 
consider do it with tensor in NetWithLoss + length = random.randint(1, x.shape[2]) + x = x[:, :, :length, :, :] + elif self.mixed_strategy == "image_only": + x = x[:, :, :1, :, :] + else: + raise ValueError + return x + + +def create_dataloader( + ds_config, + batch_size, + mixed_strategy=None, + mixed_image_ratio=0.0, + num_parallel_workers=12, + max_rowsize=32, + shuffle=True, + device_num=1, + rank_id=0, + drop_remainder=True, +): + """ + Args: + mixed_strategy: + None - all output batches are videoes [bs, c, T, h, w] + mixed_video_image - with prob of mixed_image_ratio, output batch are images [b, c, 1, h, w] + mixed_video_random - output batch has a random number of frames [bs, c, t, h, w], t is the same of samples in a batch + mixed_image_ratio: + ds_config, dataset config, args for ImageDataset or VideoDataset + ds_name: dataset name, image or video + """ + dataset = VideoDataset(**ds_config) + print("Total number of samples: ", len(dataset)) + + # Larger value leads to more memory consumption. Default: 16 + # prefetch_size = config.get("prefetch_size", 16) + # ms.dataset.config.set_prefetch_size(prefetch_size) + + dataloader = ms.dataset.GeneratorDataset( + source=dataset, + column_names=["video"], + num_shards=device_num, + shard_id=rank_id, + python_multiprocessing=True, + shuffle=shuffle, + num_parallel_workers=num_parallel_workers, + max_rowsize=max_rowsize, + ) + + dl = dataloader.batch( + batch_size, + drop_remainder=drop_remainder, + ) + + if mixed_strategy is not None: + batch_map_fn = BatchTransform(mixed_strategy, mixed_image_ratio) + dl = dl.map( + operations=batch_map_fn, + input_columns=["video"], + num_parallel_workers=1, + ) + + return dl + + +if __name__ == "__main__": + test = "dl" + if test == "dataset": + ds_config = dict( + data_folder="../videocomposer/datasets/webvid5", + random_crop=True, + flip=True, + ) + # test source dataset + ds = VideoDataset(**ds_config) + sample = ds.__getitem__(0) + print(sample.shape) + + check_sanity(sample) + else: + import math + import time + + from tqdm import tqdm + + ds_config = dict( + csv_path="../videocomposer/datasets/webvid5_copy.csv", + data_folder="../videocomposer/datasets/webvid5", + sample_n_frames=17, + size=128, + crop_size=128, + ) + + # test loader + dl = create_dataloader( + ds_config, + 4, + mixed_strategy="mixed_video_random", + mixed_image_ratio=0.2, + ) + + num_batches = dl.get_dataset_size() + # ms.set_context(mode=0) + print(num_batches) + + steps = 50 + iterator = dl.create_dict_iterator(100) # create 100 repeats + tot = 0 + + progress_bar = tqdm(range(steps)) + progress_bar.set_description("Steps") + + start = time.time() + for epoch in range(math.ceil(steps / num_batches)): + for i, batch in enumerate(iterator): + print("epoch", epoch, "step", i) + dur = time.time() - start + tot += dur + + if epoch * num_batches + i < 50: + for k in batch: + print(k, batch[k].shape, batch[k].dtype) # , batch[k].min(), batch[k].max()) + print(f"time cost: {dur * 1000} ms") + + progress_bar.update(1) + if i + 1 > steps: # in case the data size is too large + break + start = time.time() + + mean = tot / steps + print("Avg batch loading time: ", mean) diff --git a/examples/opensora_hpcai/opensora/models/vae/losses.py b/examples/opensora_hpcai/opensora/models/vae/losses.py new file mode 100644 index 0000000000..ef6d9b0d6a --- /dev/null +++ b/examples/opensora_hpcai/opensora/models/vae/losses.py @@ -0,0 +1,236 @@ +import mindspore as ms +from mindspore import nn, ops + +from .lpips import LPIPS + + +def _rearrange_in(x): + 
b, c, t, h, w = x.shape + x = x.permute(0, 2, 1, 3, 4) + x = ops.reshape(x, (b * t, c, h, w)) + + return x + + +class GeneratorWithLoss(nn.Cell): + def __init__( + self, + autoencoder, + kl_weight=1.0e-06, + perceptual_weight=1.0, + logvar_init=0.0, + use_real_rec_loss=False, + use_z_rec_loss=False, + use_image_identity_loss=False, + dtype=ms.float32, + ): + super().__init__() + + # build perceptual models for loss compute + self.autoencoder = autoencoder + # TODO: set dtype for LPIPS ? + self.perceptual_loss = LPIPS() # freeze params inside + + # self.l1 = nn.L1Loss(reduction="none") + # TODO: is self.logvar trainable? + self.logvar = ms.Parameter(ms.Tensor([logvar_init], dtype=ms.float32)) + + self.kl_weight = kl_weight + self.perceptual_weight = perceptual_weight + + self.use_real_rec_loss = use_real_rec_loss + self.use_z_rec_loss = use_z_rec_loss + self.use_image_identity_loss = use_image_identity_loss + + def kl(self, mean, logvar): + # cast to fp32 to avoid overflow in exp and sum ops. + mean = mean.astype(ms.float32) + logvar = logvar.astype(ms.float32) + + var = ops.exp(logvar) + kl_loss = 0.5 * ops.sum( + ops.pow(mean, 2) + var - 1.0 - logvar, + dim=[1, 2, 3], + ) + return kl_loss + + def vae_loss_fn( + self, x, recons, mean, logvar, nll_weights=None, no_perceptual=False, no_kl=False, pixelwise_mean=False + ): + """ + return: + nll_loss: weighted sum of pixel reconstruction loss and perceptual loss + weighted_nll_loss: weighted mean of nll_loss + weighted_kl_loss: KL divergence on posterior + """ + bs = x.shape[0] + # (b c t h w) -> (b*t c h w) + x = _rearrange_in(x) + recons = _rearrange_in(recons) + + # reconstruction loss in pixels + # FIXME: debugging: use pixelwise mean to reduce loss scale + if pixelwise_mean: + rec_loss = ((x - recons) ** 2).mean() + else: + rec_loss = ops.abs(x - recons) + + # perceptual loss + if (self.perceptual_weight > 0) and (not no_perceptual): + p_loss = self.perceptual_loss(x, recons) + rec_loss = rec_loss + self.perceptual_weight * p_loss + + nll_loss = rec_loss / ops.exp(self.logvar) + self.logvar + if nll_weights is not None: + weighted_nll_loss = nll_weights * nll_loss + weighted_nll_loss = weighted_nll_loss.sum() / bs + else: + weighted_nll_loss = nll_loss.sum() / bs + + # kl loss + # TODO: FIXME: it may not fit for graph mode training + if (self.kl_weight > 0) and (not no_kl): + kl_loss = self.kl(mean, logvar) + kl_loss = kl_loss.sum() / bs + weighted_kl_loss = self.kl_weight * kl_loss + else: + weighted_kl_loss = 0 + + return nll_loss, weighted_nll_loss, weighted_kl_loss + + def construct(self, x: ms.Tensor, global_step: ms.Tensor = -1, weights: ms.Tensor = None, cond=None): + """ + x: input images or videos, images: (b c 1 h w), videos: (b c t h w) + weights: sample weights + global_step: global training step + """ + print("D--: x shape: ", x.shape) + # 3d vae forward, get posterior (mean, logvar) and recons + # x -> VAE2d-Enc -> x_z -> TemporalVAE-Enc -> z ~ posterior -> TempVAE-Dec -> x_z_rec -> VAE2d-Dec -> x_rec + x_rec, x_z_rec, z, posterior_mean, posterior_logvar, x_z = self.autoencoder(x) + # FIXME: debugging + x_rec, x_z_rec, z, posterior_mean, posterior_logvar, x_z = ( + x_rec.to(ms.float32), + x_z_rec.to(ms.float32), + z.to(ms.float32), + posterior_mean.to(ms.float32), + posterior_logvar.to(ms.float32), + x_z.to(ms.float32), + ) + + frames = x.shape[2] + + # Loss compute + # 1. 
VAE 2d, video frames x reconstruction loss + # TODO: loss dtype setting + if self.use_real_rec_loss: + # x: (b 3 t h w) + _, weighted_nll_loss, weighted_kl_loss = self.vae_loss_fn( + x, x_rec, posterior_mean, posterior_logvar, no_perceptual=False + ) + loss = weighted_nll_loss + weighted_kl_loss + else: + loss = 0 + + # 2. temporal vae, spatial latent x_z reconstruction loss + if self.use_z_rec_loss: + # x_z: (b 4 t h//8 w//8) + # NOTE: since KL loss on posterior is the same as that in part 1. We can skip it. + _, weighted_nll_loss_z, _ = self.vae_loss_fn( + x_z, x_z_rec, posterior_mean, posterior_logvar, no_perceptual=True, no_kl=True + ) + loss += weighted_nll_loss_z + + # 3. identity regularization loss for pure image input + if self.use_image_identity_loss and frames == 1: + _, image_identity_loss, _ = self.vae_loss_fn( + x_z, z, posterior_mean, posterior_logvar, no_perceptual=True, no_kl=True + ) + loss += image_identity_loss + + return loss + + +# Discriminator is not used in opensora v1.2 +class DiscriminatorWithLoss(nn.Cell): + """ + Training logic: + For training step i, input data x: + 1. AE generator takes input x, feedforward to get posterior/latent and reconstructed data, and compute ae loss + 2. AE optimizer updates AE trainable params + 3. D takes the same input x, feed x to AE again **again** to get + the new posterior and reconstructions (since AE params has updated), feed x and recons to D, and compute D loss + 4. D optimizer updates D trainable params + --> Go to next training step + Ref: sd-vae training + """ + + def __init__( + self, + autoencoder, + discriminator, + disc_start=50001, + disc_factor=1.0, + disc_loss="hinge", + ): + super().__init__() + self.autoencoder = autoencoder + self.discriminator = discriminator + self.disc_start = disc_start + self.disc_factor = disc_factor + + assert disc_loss in ["hinge", "vanilla"] + if disc_loss == "hinge": + self.disc_loss = self.hinge_loss + else: + self.softplus = ops.Softplus() + self.disc_loss = self.vanilla_d_loss + + def hinge_loss(self, logits_real, logits_fake): + loss_real = ops.mean(ops.relu(1.0 - logits_real)) + loss_fake = ops.mean(ops.relu(1.0 + logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + def vanilla_d_loss(self, logits_real, logits_fake): + d_loss = 0.5 * (ops.mean(self.softplus(-logits_real)) + ops.mean(self.softplus(logits_fake))) + return d_loss + + def construct(self, x: ms.Tensor, global_step=-1, cond=None): + """ + Second pass + Args: + x: input image/video, (bs c h w) + weights: sample weights + """ + + # 1. AE forward, get posterior (mean, logvar) and recons + recons, mean, logvar = ops.stop_gradient(self.autoencoder(x)) + + if x.ndim >= 5: + # TODO: use 3D discriminator + # x: b c t h w -> (b*t c h w), shape for image perceptual loss + x = _rearrange_in(x) + recons = _rearrange_in(recons) + + # 2. 
Disc forward to get class prediction on real input and reconstrucions + if cond is None: + logits_real = self.discriminator(x) + logits_fake = self.discriminator(recons) + else: + logits_real = self.discriminator(ops.concat((x, cond), dim=1)) + logits_fake = self.discriminator(ops.concat((recons, cond), dim=1)) + + if global_step >= self.disc_start: + disc_factor = self.disc_factor + else: + disc_factor = 0.0 + + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + # log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), + # "{}/logits_real".format(split): logits_real.detach().mean(), + # "{}/logits_fake".format(split): logits_fake.detach().mean() + # } + + return d_loss diff --git a/examples/opensora_hpcai/opensora/models/vae/lpips.py b/examples/opensora_hpcai/opensora/models/vae/lpips.py new file mode 100644 index 0000000000..ca1fbb4442 --- /dev/null +++ b/examples/opensora_hpcai/opensora/models/vae/lpips.py @@ -0,0 +1,139 @@ +import logging +import os + +import mindcv +from opensora.utils.load_models import load_from_pretrained + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops + +_logger = logging.getLogger(__name__) + + +class LPIPS(nn.Cell): + # Learned perceptual metric + def __init__(self, use_dropout=True): + super().__init__() + self.scaling_layer = ScalingLayer() + self.chns = [64, 128, 256, 512, 512] # vgg16 features + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + # load NetLin metric layers + self.load_lpips() + + self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + self.lins = nn.CellList(self.lins) + + # create vision backbone and load pretrained weights + self.net = vgg16(pretrained=True, requires_grad=False) + + self.set_train(False) + for param in self.trainable_params(): + param.requires_grad = False + + def load_lpips(self, ckpt_path="models/lpips_vgg-426bf45c.ckpt"): + if not os.path.exists(ckpt_path): + ckpt_path = "https://download-mindspore.osinfra.cn/toolkits/mindone/autoencoders/lpips_vgg-426bf45c.ckpt" + load_from_pretrained(self, ckpt_path) + + _logger.info("loaded pretrained LPIPS loss from {}".format(ckpt_path)) + + def construct(self, input, target): + in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + val = 0 # ms.Tensor(0, dtype=input.dtype) + for kk in range(len(self.chns)): + diff = (normalize_tensor(outs0[kk]) - normalize_tensor(outs1[kk])) ** 2 + # res += spatial_average(lins[kk](diff), keepdim=True) + # lin_layer = lins[kk] + val += ops.mean(self.lins[kk](diff), axis=[2, 3], keep_dims=True) + return val + + +class ScalingLayer(nn.Cell): + def __init__(self): + super(ScalingLayer, self).__init__() + self.shift = ms.Tensor([-0.030, -0.088, -0.188])[None, :, None, None] + self.scale = ms.Tensor([0.458, 0.448, 0.450])[None, :, None, None] + + def construct(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Cell): + """A single linear layer which does a 1x1 conv""" + + def __init__(self, chn_in, chn_out=1, use_dropout=False, dtype=ms.float32): + super(NetLinLayer, self).__init__() + # TODO: can parse dtype=dtype in ms2.3 + layers = ( + [ + nn.Dropout(p=0.5).to_float(dtype), + ] + if 
(use_dropout) + else [] + ) + layers += [ + nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, has_bias=False).to_float(dtype), + ] + self.model = nn.SequentialCell(layers) + + def construct(self, x): + return self.model(x) + + +class vgg16(nn.Cell): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + # FIXME: add bias in vgg. use the same model weights in PT. + model = mindcv.create_model("vgg16", pretrained=pretrained) + model.set_train(False) + vgg_pretrained_features = model.features + self.slice1 = nn.SequentialCell() + self.slice2 = nn.SequentialCell() + self.slice3 = nn.SequentialCell() + self.slice4 = nn.SequentialCell() + self.slice5 = nn.SequentialCell() + self.N_slices = 5 + for x in range(4): + self.slice1.append(vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.append(vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.append(vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.append(vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.append(vgg_pretrained_features[x]) + if not requires_grad: + for param in self.trainable_params(): + param.requires_grad = False + for param in model.trainable_params(): + param.requires_grad = False + + def construct(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + out = (h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + return out + + +def normalize_tensor(x, eps=1e-10): + norm_factor = ops.sqrt((x**2).sum(1, keepdims=True)) + return x / (norm_factor + eps) + + +def spatial_average(x, keepdim=True): + return x.mean([2, 3], keep_dims=keepdim) diff --git a/examples/opensora_hpcai/opensora/models/vae/vae.py b/examples/opensora_hpcai/opensora/models/vae/vae.py index f23e8c3804..50bfe48a73 100644 --- a/examples/opensora_hpcai/opensora/models/vae/vae.py +++ b/examples/opensora_hpcai/opensora/models/vae/vae.py @@ -4,7 +4,7 @@ from transformers import PretrainedConfig import mindspore as ms -from mindspore import mint, nn, ops +from mindspore import nn, ops from ..layers.operation_selector import get_split_op from .autoencoder_kl import AutoencoderKL as AutoencoderKL_SD @@ -51,7 +51,7 @@ def encode_with_moments_output(self, x): """For latent caching usage""" h = self.encoder(x) moments = self.quant_conv(h) - mean, logvar = mint.split(moments, moments.shape[1] // 2, 1) + mean, logvar = self.split(moments, moments.shape[1] // 2, 1) logvar = ops.clip_by_value(logvar, -30.0, 20.0) std = self.exp(0.5 * logvar) @@ -130,12 +130,6 @@ def encode(self, x): if self.micro_batch_size is None: x_out = self.module.encode(x) * self.scale_factor else: - """ - x_splits = mint.split(x, self.micro_batch_size, 0) - x_out = tuple((self.module.encode(x_bs) * self.scale_factor) for x_bs in x_splits) - x_out = ops.cat(x_out, axis=0) - """ - bs = self.micro_batch_size x_out = self.module.encode(x[:bs]) * self.scale_factor for i in range(bs, x.shape[0], bs): @@ -158,14 +152,8 @@ def decode(self, x, **kwargs): if self.micro_batch_size is None: x_out = self.module.decode(x / self.scale_factor) else: - """ - # can try after split op bug fixed - x_splits = mint.split(x, self.micro_batch_size, 0) - x_out = tuple(self.module.decode(x_bs / self.scale_factor) for x_bs in x_splits) - x_out = ops.cat(x_out, axis=0) - """ - mbs = self.micro_batch_size + x_out = self.module.decode(x[:mbs] / self.scale_factor) for i in range(mbs, 
x.shape[0], mbs): x_cur = self.module.decode(x[i : i + mbs] / self.scale_factor) @@ -243,6 +231,7 @@ def __init__(self, config: VideoAutoencoderPipelineConfig): self.cal_loss = config.cal_loss self.micro_frame_size = config.micro_frame_size self.micro_z_frame_size = self.temporal_vae.get_latent_size([config.micro_frame_size, None, None])[0] + print(f"micro_frame_size: {self.micro_frame_size}, micro_z_frame_size: {self.micro_z_frame_size}") if config.freeze_vae_2d: for param in self.spatial_vae.get_parameters(): @@ -310,15 +299,14 @@ def decode(self, z, num_frames=None): else: return x else: - # z: (b Z t//4 h w) mz = self.micro_z_frame_size - x_z_out = self.temporal_vae.decode(z[:, :, :mz], num_frames=min(self.micro_frame_size, num_frames)) + remain_frames = num_frames if self.micro_frame_size > num_frames else self.micro_frame_size + x_z_out = self.temporal_vae.decode(z[:, :, :mz], num_frames=remain_frames) num_frames -= self.micro_frame_size for i in range(mz, z.shape[2], mz): - x_z_cur = self.temporal_vae.decode( - z[:, :, i : i + mz], num_frames=min(self.micro_frame_size, num_frames) - ) + remain_frames = num_frames if self.micro_frame_size > num_frames else self.micro_frame_size + x_z_cur = self.temporal_vae.decode(z[:, :, i : i + mz], num_frames=remain_frames) x_z_out = ops.cat((x_z_out, x_z_cur), axis=2) num_frames -= self.micro_frame_size @@ -365,6 +353,7 @@ def OpenSoraVAE_V1_2( ckpt_path: path to the checkpoint of the overall model (vae2d + temporal vae) vae_2d_ckpt_path: path to the checkpoint of the vae 2d model. It will only be loaded when `ckpt_path` not provided. """ + if isinstance(micro_batch_size, int): if micro_batch_size <= 0: micro_batch_size = None diff --git a/examples/opensora_hpcai/opensora/utils/load_models.py b/examples/opensora_hpcai/opensora/utils/load_models.py index 19068b0b1b..19af59ab56 100644 --- a/examples/opensora_hpcai/opensora/utils/load_models.py +++ b/examples/opensora_hpcai/opensora/utils/load_models.py @@ -1,258 +1,69 @@ import logging import os +import re +from typing import Union + +from mindcv.utils.download import DownLoad import mindspore as ms +from mindspore import nn -from mindone.utils.config import instantiate_from_config from mindone.utils.params import load_param_into_net_with_filter logger = logging.getLogger() -def merge_lora_to_unet(unet, lora_ckpt_path, alpha=1.0): - """ - Merge lora weights to modules of UNet cell. Make sure SD checkpoint has been loaded before invoking this function. 
- - Args: - unet: nn.Cell - lora_ckpt_path: path to lora checkpoint - alpha: the strength of LoRA, typically in range [0, 1] - Returns: - unet with updated weights - - Note: expect format - Case 1: source from torch checkpoint - lora pname: - model.diffusion_model.input_blocks.1.2.temporal_transformer.transformer_blocks.0.attention_blocks.0.processor.to_out_lora.down.weight - = {attn_layer}{lora_postfix} - = {attn_layer}.processor.{to_q/k/v/out}_lora.{down/up}.weight - mm attn dense weight name: - model.diffusion_model.input_blocks.1.2.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.weight - = {attn_layer}.{to_q/k/v/out.0}.weight - Case 2: source from ms finetuned - lora pname: - model.diffusion_model.output_blocks.1.2.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.lora_down.weight - = {attn_layer}.{to_q/k/v/out.0}.lora_{down/up}.weight - remove .lora_{down/up} - """ - lora_pdict = ms.load_checkpoint(lora_ckpt_path) - unet_pdict = unet.parameters_dict() - - for lora_pname in lora_pdict: - is_from_torch = "_lora." in lora_pname - if ("lora.down." in lora_pname) or ("lora_down." in lora_pname): # skip lora.up - lora_down_pname = lora_pname - if is_from_torch: - lora_up_pname = lora_pname.replace("lora.down.", "lora.up.") - else: - lora_up_pname = lora_pname.replace("lora_down.", "lora_up.") - - # 1. locate the target attn dense layer weight (q/k/v/out) by param name - if is_from_torch: - attn_pname = ( - lora_pname.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "") - ) - attn_pname = attn_pname.replace("to_out.", "to_out.0.") - else: - attn_pname = lora_pname.replace("lora_down.", "").replace("lora_up.", "") - - # 2. merge lora up and down weight to target dense layer weight - down_weight = lora_pdict[lora_down_pname] - up_weight = lora_pdict[lora_up_pname] - - dense_weight = unet_pdict[attn_pname].value() - merged_weight = dense_weight + alpha * ms.ops.matmul(up_weight, down_weight) - - unet_pdict[attn_pname].set_data(merged_weight) - - logger.info(f"Inspected LoRA rank: {down_weight.shape[0]}") +def is_url(string): + # Regex to check for URL patterns + url_pattern = re.compile(r"^(http|https|ftp)://") + return bool(url_pattern.match(string)) - return unet +def load_from_pretrained( + net: nn.Cell, + checkpoint: Union[str, dict], + ignore_net_params_not_loaded=False, + ensure_all_ckpt_params_loaded=False, + cache_dir: str = None, +): + """load checkpoint into network. -def merge_motion_lora_to_mm_pdict(mm_param_dict, lora_ckpt_path, alpha=1.0): - """ - Merge lora weights to montion module param dict. So that we don't need to load param dict to UNet twice. Args: - mm_param_dict: motion module param dict - lora_ckpt_path: path to lora checkpoint - alpha: the strength of LoRA, typically in range [0, 1] - Returns: - updated motion module param dict + net: network + checkpoint: local file path to checkpoint, or url to download checkpoint, or a dict for network parameters + ignore_net_params_not_loaded: set True for inference if only a part of network needs to be loaded, the flushing net-not-loaded warnings will disappear. + ensure_all_ckpt_params_loaded : set True for inference if you want to ensure no checkpoint param is missed in loading + cache_dir: directory to cache the downloaded checkpoint, only effective when `checkpoint` is a url. """ - lora_pdict = ms.load_checkpoint(lora_ckpt_path) - - for lora_pname in lora_pdict: - if "lora.down." 
in lora_pname: # skip lora.up - lora_down_pname = lora_pname - lora_up_pname = lora_pname.replace("lora.down.", "lora.up.") - - # 1. locate the target attn dense layer weight (q/k/v/out) by param name - attn_pname = ( - lora_pname.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "") - ) - attn_pname = attn_pname.replace("to_out.", "to_out.0.") - - # 2. merge lora up and down weight to target dense layer weight - down_weight = lora_pdict[lora_down_pname] - up_weight = lora_pdict[lora_up_pname] - - dense_weight = mm_param_dict[attn_pname].value() - merged_weight = dense_weight + alpha * ms.ops.matmul(up_weight, down_weight) - - mm_param_dict[attn_pname].set_data(merged_weight) - - return mm_param_dict - - -def update_unet2d_params_for_unet3d(ckpt_param_dict, unet3d_type="adv2"): - # after injecting temporal moduels to unet2d cell, param name of some layers are changed. - # apply the change to ckpt param names as well to load all unet ckpt params to unet3d cell - - # map the name change from 2d to 3d, annotated from vimdiff compare, - if unet3d_type == "adv2": - prefix_mapping = { - "model.diffusion_model.middle_block.2": "model.diffusion_model.middle_block.3", - "model.diffusion_model.output_blocks.2.1": "model.diffusion_model.output_blocks.2.2", - "model.diffusion_model.output_blocks.5.2": "model.diffusion_model.output_blocks.5.3", - "model.diffusion_model.output_blocks.8.2": "model.diffusion_model.output_blocks.8.3", - "model.diffusion_model.out.0": "model.diffusion_model.conv_norm_out", - "model.diffusion_model.out.2.conv": "model.diffusion_model.out.1.conv", - } - elif unet3d_type == "adv1": - prefix_mapping = { - "model.diffusion_model.output_blocks.2.1": "model.diffusion_model.output_blocks.2.2", - "model.diffusion_model.output_blocks.5.2": "model.diffusion_model.output_blocks.5.3", - "model.diffusion_model.output_blocks.8.2": "model.diffusion_model.output_blocks.8.3", - "model.diffusion_model.out.0": "model.diffusion_model.conv_norm_out", - "model.diffusion_model.out.2.conv": "model.diffusion_model.out.1.conv", - } - - pnames = list(ckpt_param_dict.keys()) - for pname in pnames: - for prefix_2d, prefix_3d in prefix_mapping.items(): - if pname.startswith(prefix_2d): - new_pname = pname.replace(prefix_2d, prefix_3d) - ckpt_param_dict[new_pname] = ckpt_param_dict.pop(pname) - - return ckpt_param_dict - - -def load_motion_modules( - unet, motion_module_path, motion_lora_config=None, add_ldm_prefix=True, ldm_prefix="model.diffusion_model." 
-): - # load motion module weights if use mm - logger.info("Loading motion module from {}".format(motion_module_path)) - mm_state_dict = ms.load_checkpoint(motion_module_path) - - def _clear_insertion_from_training(param_name): - return param_name.replace("diffusion_model.diffusion_model.", "diffusion_model.").replace("._backbone.", ".") - - # add prefix (used in the whole sd model) to param if needed - mm_pnames = list(mm_state_dict.keys()) - for pname in mm_pnames: - if add_ldm_prefix: - if not pname.startswith(ldm_prefix): - new_pname = ldm_prefix + pname - # remove duplicated "diffusion_model" caused by saving mm only during training - new_pname = _clear_insertion_from_training(new_pname) - mm_state_dict[new_pname] = mm_state_dict.pop(pname) - - params_not_load, ckpt_not_load = load_param_into_net_with_filter( - unet, - mm_state_dict, - filter=mm_state_dict.keys(), - ) - if len(ckpt_not_load) > 0: - logger.warning( - "The following params in mm ckpt are not loaded into net: {}\nTotal: {}".format( - ckpt_not_load, len(ckpt_not_load) - ) - ) - assert len(ckpt_not_load) == 0, "All params in motion module must be loaded" - - # motion lora - if motion_lora_config is not None: - if motion_lora_config["path"] not in ["", None]: - _mlora_path, alpha = motion_lora_config["path"], motion_lora_config["alpha"] - logger.info("Loading motion lora from {}".format(_mlora_path)) - unet = merge_lora_to_unet(unet, _mlora_path, alpha) - - return unet - - -def load_adapter_lora(unet, adapter_lora_path, adapter_lora_alpha): - # load motion module weights if use mm - logger.info("Loading domain adapter lora module from {}".format(adapter_lora_path)) - unet = merge_lora_to_unet(unet, adapter_lora_path, adapter_lora_alpha) - - return unet - - -def load_controlnet(sd_model, controlnet_path, verbose=True): - logger.info("Loading sparse control encoder from {}".format(controlnet_path)) - controlnet_state_dict = ms.load_checkpoint(controlnet_path) - controlnet_state_dict = ( - controlnet_state_dict["controlnet"] if "controlnet" in controlnet_state_dict else controlnet_state_dict - ) - filter_list = list(controlnet_state_dict.keys()) - param_not_load, ckpt_not_load = load_param_into_net_with_filter(sd_model, controlnet_state_dict, filter=filter_list) - assert ( - len(ckpt_not_load) == 0 - ), f"All params in SD checkpoint must be loaded. 
but got these not loaded {ckpt_not_load}" - if verbose: - if len(param_not_load) > 0: - logger.info("Net params not loaded: {}".format([p for p in param_not_load if not p.startswith("adam")])) - return sd_model - + if isinstance(checkpoint, str): + if is_url(checkpoint): + url = checkpoint + cache_dir = os.path.join(os.path.expanduser("~"), ".mindspore/models") if cache_dir is None else cache_dir + os.makedirs(cache_dir, exist_ok=True) + DownLoad().download_url(url, path=cache_dir) + checkpoint = os.path.join(cache_dir, os.path.basename(url)) + if os.path.exists(checkpoint): + param_dict = ms.load_checkpoint(checkpoint) + else: + raise FileNotFoundError(f"{checkpoint} doesn't exist") + elif isinstance(checkpoint, dict): + param_dict = checkpoint + else: + raise TypeError(f"unknown checkpoint type: {checkpoint}") -def build_model_from_config(config, ckpt: str, is_training=False, use_motion_module=True): - def _load_model(_model, checkpoint, verbose=True, ignore_net_param_not_load_warning=False): - if isinstance(checkpoint, str): - if os.path.exists(checkpoint): - param_dict = ms.load_checkpoint(checkpoint) - else: - raise FileNotFoundError(f"{checkpoint} doesn't exist") - elif isinstance(checkpoint, dict): - param_dict = checkpoint + if param_dict: + if ignore_net_params_not_loaded: + filter = param_dict.keys() else: - raise TypeError(f"unknown checkpoint type: {checkpoint}") + filter = None + param_not_load, ckpt_not_load = load_param_into_net_with_filter(net, param_dict, filter=filter) - if param_dict: - if ignore_net_param_not_load_warning: - filter = param_dict.keys() - else: - filter = None - param_not_load, ckpt_not_load = load_param_into_net_with_filter(_model, param_dict, filter=filter) + if ensure_all_ckpt_params_loaded: assert ( len(ckpt_not_load) == 0 - ), f"All params in SD checkpoint must be loaded. but got these not loaded {ckpt_not_load}" - if verbose: - if len(param_not_load) > 0: - logger.info( - "Net params not loaded: {}".format([p for p in param_not_load if not p.startswith("adam")]) - ) - - model = instantiate_from_config(config.model) - if ckpt != "": - param_dict = ms.load_checkpoint(ckpt) - - # update param dict loading unet2d checkpoint to unet3d - if use_motion_module: - if config.model.params.unet_config.params.motion_module_mid_block: - unet3d_type = "adv2" - else: - unet3d_type = "adv1" - param_dict = update_unet2d_params_for_unet3d(param_dict, unet3d_type=unet3d_type) - - logger.info(f"Loading main model from {ckpt}") - _load_model(model, param_dict, ignore_net_param_not_load_warning=True) - else: - logger.warning("No pretarined weights loaded. Input checkpoint path is empty.") - - if not is_training: - model.set_train(False) - for param in model.trainable_params(): - param.requires_grad = False + ), f"All params in checkpoint must be loaded. 
but got these not loaded {ckpt_not_load}" - return model + if not ignore_net_params_not_loaded: + if len(param_not_load) > 0: + logger.info("Net params not loaded: {}".format([p for p in param_not_load if not p.startswith("adam")])) + logger.info("Checkpoint params not loaded: {}".format([p for p in ckpt_not_load if not p.startswith("adam")])) diff --git a/examples/opensora_hpcai/requirements.txt b/examples/opensora_hpcai/requirements.txt index 2980b29c41..ef1761de3e 100644 --- a/examples/opensora_hpcai/requirements.txt +++ b/examples/opensora_hpcai/requirements.txt @@ -18,3 +18,4 @@ tokenizers sentencepiece transformers pyav +mindcv=0.3.0 diff --git a/examples/opensora_hpcai/scripts/args_train.py b/examples/opensora_hpcai/scripts/args_train.py index f1f5718931..37edb4e241 100644 --- a/examples/opensora_hpcai/scripts/args_train.py +++ b/examples/opensora_hpcai/scripts/args_train.py @@ -142,9 +142,9 @@ def parse_train_args(parser): parser.add_argument( "--group_strategy", type=str, - default="norm_and_bias", + default=None, help="Grouping strategy for weight decay. If `norm_and_bias`, weight decay filter list is [beta, gamma, bias]. \ - If None, filter list is [layernorm, bias]. Default: norm_and_bias", + If None, filter list is [layernorm, bias]. Default: None", ) parser.add_argument("--weight_decay", default=1e-6, type=float, help="Weight decay.") diff --git a/examples/opensora_hpcai/scripts/args_train_vae.py b/examples/opensora_hpcai/scripts/args_train_vae.py new file mode 100644 index 0000000000..352a90483d --- /dev/null +++ b/examples/opensora_hpcai/scripts/args_train_vae.py @@ -0,0 +1,296 @@ +import argparse +import logging +import os +import sys + +import yaml + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../../")) +sys.path.insert(0, mindone_lib_path) + +from opensora.utils.model_utils import _check_cfgs_in_parser, str2bool + +from mindone.utils.misc import to_abspath + +logger = logging.getLogger() + + +def parse_train_args(parser): + parser.add_argument( + "--config", + "-c", + default="", + type=str, + help="path to load a config yaml file that describes the training recipes which will override the default arguments", + ) + # the following args's defualt value will be overrided if specified in config yaml + + # data + parser.add_argument("--dataset_name", default="", type=str, help="dataset name") + parser.add_argument( + "--csv_path", + default="", + type=str, + help="path to csv annotation file. columns: video, caption. \ + video indicates the relative path of video file in video_folder. 
caption - the text caption for video", + ) + parser.add_argument("--video_column", default="video", type=str, help="name of column for videos saved in csv file") + parser.add_argument("--random_crop", default=False, type=str2bool, help="randonly crop the image") + parser.add_argument("--flip", default=False, type=str2bool, help="flip the image") + + parser.add_argument( + "--caption_column", default="caption", type=str, help="name of column for captions saved in csv file" + ) + parser.add_argument("--video_folder", default="", type=str, help="root dir for the video data") + parser.add_argument("--output_path", default="output/", type=str, help="output directory to save training results") + parser.add_argument( + "--add_datetime", default=True, type=str, help="If True, add datetime subfolder under output_path" + ) + # model + parser.add_argument("--model_type", default="OpenSora-VAE-v1.2", type=str, help="VAE model type") + parser.add_argument("--freeze_vae_2d", default=True, type=str2bool, help="Freeze 2d vae") + parser.add_argument( + "--use_discriminator", default=False, type=str2bool, help="Use discriminator for adversarial training." + ) + parser.add_argument( + "--pretrained_model_path", + default="", + type=str, + help="Specify the pretrained model path", + ) + parser.add_argument("--perceptual_loss_weight", default=0.1, type=float, help="perceptual (lpips) loss weight") + parser.add_argument("--kl_loss_weight", default=1.0e-6, type=float, help="KL loss weight") + parser.add_argument("--use_real_rec_loss", default=False, type=str2bool, help="use vae 2d x reconstruction loss") + parser.add_argument("--use_z_rec_loss", default=False, type=str2bool, help="use spatial vae z reconstruction loss") + parser.add_argument( + "--use_image_identity_loss", + default=False, + type=str2bool, + help="use image identity reguralization loss for temporal vae encoder", + ) + # data + parser.add_argument("--mixed_strategy", type=str, default=None, help="video and image mixed strategy") + parser.add_argument( + "--mixed_image_ratio", default=0.0, type=float, help="image ratio in mixed video and image data training" + ) + + # ms + parser.add_argument("--debug", type=str2bool, default=False, help="Execute inference in debug mode.") + parser.add_argument("--device_target", type=str, default="Ascend", help="Ascend or GPU") + parser.add_argument("--max_device_memory", type=str, default=None, help="e.g. `30GB` for 910a, `59GB` for 910b") + parser.add_argument("--mode", default=0, type=int, help="Specify the mode: 0 for graph mode, 1 for pynative mode") + parser.add_argument("--use_parallel", default=False, type=str2bool, help="use parallel") + parser.add_argument( + "--parallel_mode", default="data", type=str, choices=["data", "optim"], help="parallel mode: data, optim" + ) + parser.add_argument("--jit_level", default="O0", type=str, help="O0 kbk, O1 dvm, O2 ge") + + # training hyper-params + parser.add_argument( + "--resume", + default=False, + type=str, + help="It can be a string for path to resume checkpoint, or a bool False for not resuming.(default=False)", + ) + parser.add_argument("--optim", default="adamw_re", type=str, help="optimizer") + parser.add_argument( + "--betas", + type=float, + nargs="+", + default=[0.9, 0.999], + help="Specify the [beta1, beta2] parameter for the AdamW optimizer.", + ) + parser.add_argument( + "--optim_eps", type=float, default=1e-8, help="Specify the eps parameter for the AdamW optimizer." 
+ ) + parser.add_argument( + "--group_strategy", + type=str, + default=None, + help="Grouping strategy for weight decay. If `norm_and_bias`, weight decay filter list is [beta, gamma, bias]. \ + If None, filter list is [layernorm, bias], Default: None", + ) + parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay.") + parser.add_argument("--seed", default=3407, type=int, help="data path") + parser.add_argument("--warmup_steps", default=1000, type=int, help="warmup steps") + parser.add_argument("--batch_size", default=10, type=int, help="batch size") + parser.add_argument( + "--micro_batch_size", + type=int, + default=4, + help="If not None, split batch_size*num_frames into smaller ones for VAE encoding to reduce memory limitation", + ) + parser.add_argument( + "--micro_frame_size", + type=int, + default=17, + help="If not None, split batch_size*num_frames into smaller ones for VAE encoding to reduce memory limitation. Used by temporal vae", + ) + parser.add_argument("--start_learning_rate", default=1e-5, type=float, help="The initial learning rate for Adam.") + parser.add_argument("--end_learning_rate", default=1e-7, type=float, help="The end learning rate for Adam.") + parser.add_argument( + "--scale_lr", default=False, type=str2bool, help="scale base-lr by ngpu * batch_size * n_accumulate" + ) + parser.add_argument("--decay_steps", default=0, type=int, help="lr decay steps.") + parser.add_argument("--scheduler", default="cosine_decay", type=str, help="scheduler.") + parser.add_argument("--pre_patchify", default=False, type=str2bool, help="Training with patchified latent.") + parser.add_argument( + "--max_image_size", default=512, type=int, help="Max image size for patchified latent training." + ) + + # dataloader params + parser.add_argument("--dataset_sink_mode", default=False, type=str2bool, help="sink mode") + parser.add_argument("--sink_size", default=-1, type=int, help="dataset sink size. If -1, sink size = dataset size.") + parser.add_argument( + "--epochs", + default=10, + type=int, + help="epochs. If dataset_sink_mode is on, epochs is with respect to dataset sink size. Otherwise, it's w.r.t the dataset size.", + ) + parser.add_argument( + "--train_steps", default=-1, type=int, help="If not -1, limit the number of training steps to the set value" + ) + parser.add_argument("--init_loss_scale", default=65536, type=float, help="loss scale") + parser.add_argument("--loss_scale_factor", default=2, type=float, help="loss scale factor") + parser.add_argument("--scale_window", default=2000, type=float, help="scale window") + parser.add_argument("--gradient_accumulation_steps", default=1, type=int, help="gradient accumulation steps") + # parser.add_argument("--cond_stage_trainable", default=False, type=str2bool, help="whether text encoder is trainable") + parser.add_argument("--use_ema", default=False, type=str2bool, help="whether use EMA") + parser.add_argument("--ema_decay", default=0.9999, type=float, help="ema decay ratio") + parser.add_argument("--clip_grad", default=False, type=str2bool, help="whether apply gradient clipping") + parser.add_argument( + "--use_recompute", + default=False, + type=str2bool, + help="whether use recompute.", + ) + parser.add_argument( + "--num_recompute_blocks", + default=None, + type=int, + help="If None, all stdit blocks will be applied with recompute (gradient checkpointing). 
If int, the first N blocks will be applied with recompute", + ) + parser.add_argument( + "--dtype", + default="fp16", + type=str, + choices=["bf16", "fp16", "fp32"], + help="what computation data type to use for latte. Default is `fp16`, which corresponds to ms.float16", + ) + parser.add_argument( + "--vae_keep_gn_fp32", + default=True, + type=str2bool, + help="whether keep GroupNorm in fp32.", + ) + parser.add_argument( + "--global_bf16", + default=False, + type=str2bool, + help="Experimental. If True, dtype will be overrided, operators will be computered in bf16 if they are supported by CANN", + ) + parser.add_argument( + "--vae_param_dtype", + default="fp32", + type=str, + choices=["bf16", "fp16", "fp32"], + help="what param data type to use for vae. Default is `fp32`, which corresponds to ms.float32", + ) + parser.add_argument( + "--amp_level", + default="O2", + type=str, + help="mindspore amp level, O1: most fp32, only layers in whitelist compute in fp16 (dense, conv, etc); \ + O2: most fp16, only layers in blacklist compute in fp32 (batch norm etc)", + ) + parser.add_argument("--vae_amp_level", default="O2", type=str, help="O2 or O3") + parser.add_argument( + "--vae_checkpoint", + type=str, + default="models/sd-vae-ft-ema.ckpt", + help="VAE checkpoint file path which is used to load vae weight.", + ) + parser.add_argument( + "--sd_scale_factor", type=float, default=0.18215, help="VAE scale factor of Stable Diffusion model." + ) + parser.add_argument("--image_size", default=256, type=int, nargs="+", help="the image size used to initiate model") + parser.add_argument("--num_frames", default=16, type=int, help="the num of frames used to initiate model") + parser.add_argument("--frame_stride", default=3, type=int, help="frame sampling stride") + parser.add_argument("--mask_ratios", type=dict, help="Masking ratios") + parser.add_argument("--bucket_config", type=dict, help="Multi-resolution bucketing configuration") + parser.add_argument("--num_parallel_workers", default=12, type=int, help="num workers for data loading") + parser.add_argument( + "--data_multiprocessing", + default=False, + type=str2bool, + help="If True, use multiprocessing for data processing. Default: multithreading.", + ) + parser.add_argument("--max_rowsize", default=64, type=int, help="max rowsize for data loading") + parser.add_argument( + "--enable_flash_attention", + default=None, + type=str2bool, + help="whether to enable flash attention.", + ) + parser.add_argument("--drop_overflow_update", default=True, type=str2bool, help="drop overflow update") + parser.add_argument("--loss_scaler_type", default="dynamic", type=str, help="dynamic or static") + parser.add_argument( + "--max_grad_norm", + default=1.0, + type=float, + help="max gradient norm for clipping, effective when `clip_grad` enabled.", + ) + parser.add_argument("--ckpt_save_interval", default=1, type=int, help="save checkpoint every this epochs") + parser.add_argument( + "--ckpt_save_steps", + default=-1, + type=int, + help="save checkpoint every this steps. If -1, use ckpt_save_interval will be used.", + ) + parser.add_argument("--ckpt_max_keep", default=10, type=int, help="Maximum number of checkpoints to keep") + parser.add_argument( + "--step_mode", + default=False, + type=str2bool, + help="whether save ckpt by steps. 
If False, save ckpt by epochs.", + ) + parser.add_argument("--profile", default=False, type=str2bool, help="Profile or not") + parser.add_argument( + "--log_level", + type=str, + default="logging.INFO", + help="log level, options: logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR", + ) + parser.add_argument( + "--log_interval", + default=1, + type=int, + help="log interval in the unit of data sink size.. E.g. if data sink size = 10, log_inteval=2, log every 20 steps", + ) + return parser + + +def parse_args(): + parser = argparse.ArgumentParser() + parser = parse_train_args(parser) + + __dir__ = os.path.dirname(os.path.abspath(__file__)) + abs_path = os.path.abspath(os.path.join(__dir__, "..")) + default_args = parser.parse_args() + if default_args.config: + default_args.config = to_abspath(abs_path, default_args.config) + with open(default_args.config, "r") as f: + cfg = yaml.safe_load(f) + _check_cfgs_in_parser(cfg, parser) + parser.set_defaults(**cfg) + args = parser.parse_args() + # convert to absolute path, necessary for modelarts + args.csv_path = to_abspath(abs_path, args.csv_path) + args.video_folder = to_abspath(abs_path, args.video_folder) + args.output_path = to_abspath(abs_path, args.output_path) + args.pretrained_model_path = to_abspath(abs_path, args.pretrained_model_path) + args.vae_checkpoint = to_abspath(abs_path, args.vae_checkpoint) + print(args) + + return args diff --git a/examples/opensora_hpcai/scripts/inference_vae.py b/examples/opensora_hpcai/scripts/inference_vae.py new file mode 100644 index 0000000000..264982d9ef --- /dev/null +++ b/examples/opensora_hpcai/scripts/inference_vae.py @@ -0,0 +1,307 @@ +# flake8: noqa +""" +Infer and evaluate autoencoders +""" +import argparse +import logging +import os +import sys +import time + +import imageio +import numpy as np + +from mindspore import nn, ops + +# mindone_dir = '/home/mindocr/yx/mindone' +mindone_dir = "/home_host/yx/mindone" +sys.path.insert(0, mindone_dir) + +# from ae.models.lpips import LPIPS +from omegaconf import OmegaConf +from PIL import Image +from skimage.metrics import peak_signal_noise_ratio as calc_psnr +from skimage.metrics import structural_similarity as calc_ssim +from tqdm import tqdm + +import mindspore as ms + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../../")) +sys.path.insert(0, mindone_lib_path) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, ".."))) + +from opensora.datasets.vae_dataset import create_dataloader +from opensora.models.vae.vae import SD_CONFIG, SDXL_CONFIG, OpenSoraVAE_V1_2, VideoAutoencoderKL + +from mindone.utils.amp import auto_mixed_precision +from mindone.utils.config import instantiate_from_config, str2bool +from mindone.utils.logger import set_logger + +logger = logging.getLogger(__name__) + + +def postprocess(x, trim=True): + # postprocess for computing metrics + pixels = (x + 1) * 127.5 + pixels = np.clip(pixels, 0, 255).astype(np.uint8) + + if len(pixels.shape) == 4: + # b, c, h, w -> b h w c + return np.transpose(pixels, (0, 2, 3, 1)) + else: + # b c t h w -> b t h w c -> b*t h w c + b, c, t, h, w = pixels.shape + pixels = np.transpose(pixels, (0, 2, 3, 4, 1)) + pixels = np.reshape(pixels, (b * t, h, w, c)) + return pixels + + +def visualize_image(recons, x=None, save_fn="tmp_vae_recons"): + # x: (b h w c) + for i in range(recons.shape[0]): + if x is not None: + out = np.concatenate((x[i], recons[i]), axis=-2) + else: + out = recons[i] + 
Image.fromarray(out).save(f"{save_fn}-{i:02d}.png") + + +def visualize_video(recons, x=None, save_fn="tmp_vae3d_recons", fps=15): + # x: (b t h w c) + for i in range(recons.shape[0]): + if x is not None: + out = np.concatenate((x[i], recons[i]), axis=-2) + else: + out = recons[i] + save_fp = f"{save_fn}-{i:02d}.gif" + imageio.mimsave(save_fp, out, duration=1 / fps, loop=0) + + +def main(args): + ascend_config = {"precision_mode": "must_keep_origin_dtype"} + ms.set_context(mode=args.mode, ascend_config=ascend_config) + set_logger(name="", output_dir=args.output_path, rank=0) + + # build model + if args.use_temporal_vae: + model = OpenSoraVAE_V1_2( + micro_batch_size=4, + micro_frame_size=17, + ckpt_path=args.ckpt_path, + freeze_vae_2d=True, + ) + else: + model = VideoAutoencoderKL(config=SDXL_CONFIG, ckpt_path=args.ckpt_path, micro_batch_size=4) + + model.set_train(False) + logger.info(f"Loaded checkpoint from {args.ckpt_path}") + + # if args.eval_loss: + # lpips_loss_fn = LPIPS() + + if args.dtype != "fp32": + amp_level = "O2" + dtype = {"fp16": ms.float16, "bf16": ms.bfloat16}[args.dtype] + # FIXME: due to AvgPool and ops.interpolate doesn't support bf16, we add them to fp32 cells + custom_fp32_cells = [nn.GroupNorm, nn.AvgPool2d, nn.Upsample] + model = auto_mixed_precision(model, amp_level, dtype, custom_fp32_cells) + logger.info(f"Set mixed precision to O2 with dtype={args.dtype}") + else: + amp_level = "O0" + + # build dataset + if isinstance(args.image_size, int): + image_size = args.image_size + else: + if len(args.image_size) == 2: + assert args.image_size[0] == args.image_size[1], "Currently only h==w is supported" + image_size = args.image_size[0] + + ds_config = dict( + csv_path=args.csv_path, + data_folder=args.video_folder, + size=image_size, + crop_size=image_size, + sample_n_frames=args.num_frames, + sample_stride=args.frame_stride, + video_column=args.video_column, + random_crop=False, + flip=False, + ) + dataset = create_dataloader( + ds_config, + args.batch_size, + mixed_strategy=None, + mixed_image_ratio=0.0, + num_parallel_workers=8, + max_rowsize=256, + shuffle=False, + device_num=1, + rank_id=0, + drop_remainder=False, + ) + num_batches = dataset.get_dataset_size() + + ds_iter = dataset.create_dict_iterator(1) + + logger.info("Inferene begins") + mean_infer_time = 0 + mean_psnr = 0 + mean_ssim = 0 + mean_lpips = 0 + mean_recon = 0 + num_samples = 0 + for step, data in tqdm(enumerate(ds_iter)): + x = data["video"] + start_time = time.time() + + z = model.encode(x) + if not args.encode_only: + if args.use_temporal_vae: + recons = model.decode(z, num_frames=args.num_frames) + else: + recons = model.decode(z) + + # adapt to bf16 + recons = recons.to(ms.float32) + + infer_time = time.time() - start_time + mean_infer_time += infer_time + logger.info(f"Infer time: {infer_time}") + + if not args.encode_only: + # if args.dataset_name == 'image' and args.expand_dim_t: + # # b c t h w -> b c h w + # x = x[:,:,0,:,:] + # recons= recons[:,:,0,:,:] + is_video = len(recons.shape) == 5 and (recons.shape[-3] > 1) + t = recons.shape[-3] if is_video else 1 + + recons_rgb = postprocess(recons.asnumpy()) + x_rgb = postprocess(x.asnumpy()) + + psnr_cur = [calc_psnr(x_rgb[i], recons_rgb[i]) for i in range(x_rgb.shape[0])] + ssim_cur = [ + calc_ssim(x_rgb[i], recons_rgb[i], data_range=255, channel_axis=-1, multichannel=True) + for i in range(x_rgb.shape[0]) + ] + mean_psnr += sum(psnr_cur) + mean_ssim += sum(ssim_cur) + num_samples += x_rgb.shape[0] + + if args.eval_loss: + recon_loss = 
np.abs((x - recons).asnumpy()) + lpips_loss = lpips_loss_fn(x, recons).asnumpy() + mean_recon += recon_loss.mean() + mean_lpips += lpips_loss.mean() + + if args.save_vis: + save_fn = os.path.join( + args.output_path, "{}-{}".format(os.path.basename(args.video_folder), f"step{step:03d}") + ) + if not is_video: + visualize_image(recons_rgb, x_rgb, save_fn=save_fn) + else: + bt, h, w, c = recons_rgb.shape + recons_rgb_vis = np.reshape(recons_rgb, (bt // t, t, h, w, c)) + x_rgb_vis = np.reshape(x_rgb, (bt // t, t, h, w, c)) + visualize_video(recons_rgb_vis, x_rgb_vis, save_fn=save_fn) + + mean_infer_time /= num_batches + logger.info(f"Mean infer time: {mean_infer_time}") + logger.info(f"Done. Results saved in {args.output_path}") + + if not args.encode_only: + mean_psnr /= num_samples + mean_ssim /= num_samples + logger.info(f"mean psnr:{mean_psnr:.4f}") + logger.info(f"mean ssim:{mean_ssim:.4f}") + + if args.eval_loss: + mean_recon /= num_batches + mean_lpips /= num_batches + logger.info(f"mean recon loss: {mean_recon:.4f}") + logger.info(f"mean lpips loss: {mean_lpips:.4f}") + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_config", + default="configs/autoencoder_kl_f8.yaml", + type=str, + help="model architecture config", + ) + parser.add_argument( + "--ckpt_path", default="outputs/vae_train/ckpt/vae_kl_f8-e10.ckpt", type=str, help="checkpoint path" + ) + parser.add_argument( + "--csv_path", + default=None, + type=str, + help="path to csv annotation file. If None, will get videos from the folder of `data_path`", + ) + parser.add_argument("--video_folder", default=None, type=str, help="folder of videos") + parser.add_argument( + "--output_path", default="samples/vae_recons", type=str, help="output directory to save inference results" + ) + parser.add_argument("--num_frames", default=17, type=int, help="num frames") + parser.add_argument("--frame_stride", default=1, type=int, help="frame sampling stride") + parser.add_argument( + "--expand_dim_t", + default=False, + type=str2bool, + help="expand temporal axis for image data, used for vae 3d inference with image data", + ) + parser.add_argument("--image_size", default=256, type=int, help="image rescale size") + # parser.add_argument("--crop_size", default=256, type=int, help="image crop size") + + parser.add_argument("--batch_size", default=1, type=int, help="batch size") + parser.add_argument("--num_parallel_workers", default=8, type=int, help="num workers for data loading") + parser.add_argument( + "--eval_loss", + default=False, + type=str2bool, + help="whether measure loss including reconstruction, kl, perceptual loss", + ) + parser.add_argument("--save_vis", default=True, type=str2bool, help="whether save reconstructed images") + parser.add_argument("--use_temporal_vae", default=True, type=str2bool, help="if False, just use spatial vae") + parser.add_argument("--encode_only", default=False, type=str2bool, help="only encode to save z or distribution") + parser.add_argument("--video_column", default="video", type=str, help="name of column for videos saved in csv file") + parser.add_argument( + "--mixed_strategy", + type=str, + default=None, + choices=[None, "mixed_video_image", "image_only"], + help="video and image mixed strategy.", + ) + parser.add_argument( + "--mixed_image_ratio", default=0.0, type=float, help="image ratio in mixed video and image data training" + ) + parser.add_argument( + "--save_z_dist", + default=False, + type=str2bool, + help="If True, save z distribution, mean and 
logvar. Otherwise, save z after sampling.", + ) + # ms related + parser.add_argument("--mode", default=0, type=int, help="Specify the mode: 0 for graph mode, 1 for pynative mode") + parser.add_argument( + "--dtype", + default="fp32", + type=str, + choices=["fp32", "fp16", "bf16"], + help="mixed precision type, if fp32, all layer precision is float32 (amp_level=O0), \ + if bf16 or fp16, amp_level==O2, part of layers will compute in bf16 or fp16 such as matmul, dense, conv.", + ) + parser.add_argument("--device_target", type=str, default="Ascend", help="Ascend or GPU") + + args = parser.parse_args() + + return args + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/opensora_hpcai/scripts/train_vae.py b/examples/opensora_hpcai/scripts/train_vae.py new file mode 100644 index 0000000000..31f75af6de --- /dev/null +++ b/examples/opensora_hpcai/scripts/train_vae.py @@ -0,0 +1,455 @@ +import logging +import os +import shutil +import sys +import time +from typing import Tuple + +import yaml + +import mindspore as ms +from mindspore import Model, nn +from mindspore.communication.management import get_group_size, get_rank, init +from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell +from mindspore.train.callback import TimeMonitor + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../../")) +sys.path.insert(0, mindone_lib_path) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, ".."))) + +from args_train_vae import parse_args +from opensora.datasets.vae_dataset import create_dataloader +from opensora.models.layers.operation_selector import set_dynamic_mode +from opensora.models.vae.losses import GeneratorWithLoss +from opensora.models.vae.vae import OpenSoraVAE_V1_2 + +from mindone.trainers.callback import EvalSaveCallback, OverflowMonitor, ProfilerCallback +from mindone.trainers.checkpoint import CheckpointManager, resume_train_network +from mindone.trainers.ema import EMA +from mindone.trainers.lr_schedule import create_scheduler +from mindone.trainers.optim import create_optimizer +from mindone.trainers.train_step import TrainOneStepWrapper +from mindone.utils.amp import auto_mixed_precision +from mindone.utils.logger import set_logger +from mindone.utils.params import count_params +from mindone.utils.seed import set_random_seed + +os.environ["HCCL_CONNECT_TIMEOUT"] = "6000" +os.environ["MS_ASCEND_CHECK_OVERFLOW_MODE"] = "INFNAN_MODE" + +logger = logging.getLogger(__name__) + + +def create_loss_scaler(loss_scaler_type, init_loss_scale, loss_scale_factor=2, scale_window=1000): + if args.loss_scaler_type == "dynamic": + loss_scaler = DynamicLossScaleUpdateCell( + loss_scale_value=args.init_loss_scale, scale_factor=args.loss_scale_factor, scale_window=args.scale_window + ) + elif args.loss_scaler_type == "static": + loss_scaler = nn.FixedLossScaleUpdateCell(args.init_loss_scale) + else: + raise ValueError + + return loss_scaler + + +def init_env( + mode: int = ms.GRAPH_MODE, + seed: int = 42, + distributed: bool = False, + max_device_memory: str = None, + device_target: str = "Ascend", + parallel_mode: str = "data", + jit_level: str = "O2", + global_bf16: bool = False, + dynamic_shape: bool = False, + debug: bool = False, +) -> Tuple[int, int]: + """ + Initialize MindSpore environment. + + Args: + mode: MindSpore execution mode. Default is 0 (ms.GRAPH_MODE). + seed: The seed value for reproducibility. Default is 42. + distributed: Whether to enable distributed training. 
Default is False. + Returns: + A tuple containing the device ID, rank ID and number of devices. + """ + set_random_seed(seed) + + if debug and mode == ms.GRAPH_MODE: # force PyNative mode when debugging + logger.warning("Debug mode is on, switching execution mode to PyNative.") + mode = ms.PYNATIVE_MODE + + if max_device_memory is not None: + ms.set_context(max_device_memory=max_device_memory) + + # ms.set_context(mempool_block_size="55GB") + # ms.set_context(pynative_synchronize=True) + if distributed: + ms.set_context( + mode=mode, + device_target=device_target, + ) + if parallel_mode == "optim": + print("use optim parallel") + ms.set_auto_parallel_context( + parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, + enable_parallel_optimizer=True, + ) + init() + device_num = get_group_size() + rank_id = get_rank() + else: + init() + device_num = get_group_size() + rank_id = get_rank() + logger.debug(f"rank_id: {rank_id}, device_num: {device_num}") + ms.reset_auto_parallel_context() + + ms.set_auto_parallel_context( + parallel_mode=ms.ParallelMode.DATA_PARALLEL, + gradients_mean=True, + device_num=device_num, + ) + + var_info = ["device_num", "rank_id", "device_num / 8", "rank_id / 8"] + var_value = [device_num, rank_id, int(device_num / 8), int(rank_id / 8)] + logger.info(dict(zip(var_info, var_value))) + + else: + device_num = 1 + rank_id = 0 + ms.set_context( + mode=mode, + device_target=device_target, + pynative_synchronize=debug, + ) + + if mode == 0: + ms.set_context(jit_config={"jit_level": jit_level}) + + if global_bf16: + # only effective in GE mode, i.e. jit_level: O2 + ms.set_context(ascend_config={"precision_mode": "allow_mix_precision_bf16"}) + + if dynamic_shape: + print("Dynamic shape mode enabled, repeat_interleave/split/chunk will be called from mint module") + set_dynamic_mode(True) + if mode == 0: + # FIXME: this is a temp fix for dynamic shape training in graph mode. may remove in future version. + # can append adamw fusion flag if use nn.AdamW optimzation for acceleration + ms.set_context(graph_kernel_flags="--disable_packet_ops=Reshape") + print("D--: disable reshape packet") + + return rank_id, device_num + + +def main(args): + # 1. init + rank_id, device_num = init_env( + args.mode, + seed=args.seed, + distributed=args.use_parallel, + device_target=args.device_target, + max_device_memory=args.max_device_memory, + parallel_mode=args.parallel_mode, + jit_level=args.jit_level, + global_bf16=args.global_bf16, + dynamic_shape=(args.mixed_strategy == "mixed_video_random"), + debug=args.debug, + ) + set_logger(name="", output_dir=args.output_path, rank=rank_id, log_level=eval(args.log_level)) + + # 2. 
build data loader + + if isinstance(args.image_size, int): + image_size = args.image_size + else: + if len(args.image_size) == 2: + assert args.image_size[0] == args.image_size[1], "Currently only h==w is supported" + image_size = args.image_size[0] + + ds_config = dict( + csv_path=args.csv_path, + data_folder=args.video_folder, + size=image_size, + crop_size=image_size, + sample_n_frames=args.num_frames, + sample_stride=args.frame_stride, + video_column=args.video_column, + random_crop=args.random_crop, + flip=args.flip, + ) + dataloader = create_dataloader( + ds_config, + args.batch_size, + mixed_strategy=args.mixed_strategy, + mixed_image_ratio=args.mixed_image_ratio, + num_parallel_workers=args.num_parallel_workers, + max_rowsize=256, + shuffle=True, + device_num=device_num, + rank_id=rank_id, + drop_remainder=True, + ) + dataset_size = dataloader.get_dataset_size() + logger.info(f"Num batches: {dataset_size}") + + # 3. build models + if args.model_type == "OpenSoraVAE_V1_2": + logger.info(f"Loading autoencoder from {args.pretrained_model_path}") + if args.micro_frame_size != 17: + logger.warning( + "If you are finetuning VAE3d pretrained from OpenSora v1.2, it's safer to set micro_frame_size to 17 for consistency." + ) + ae = OpenSoraVAE_V1_2( + micro_batch_size=args.micro_batch_size, + micro_frame_size=args.micro_frame_size, + ckpt_path=args.pretrained_model_path, + freeze_vae_2d=args.freeze_vae_2d, + cal_loss=True, + use_recompute=args.use_recompute, + ) + else: + raise NotImplementedError("Only OpenSoraVAE_V1_2 is supported for vae training currently") + + if args.use_discriminator: + logging.error("Discriminator is not used or supported in OpenSora v1.2") + + # mixed precision + # TODO: set softmax, sigmoid computed in FP32. manually set inside network since they are ops, instead of layers whose precision will be set by AMP level. + if args.dtype in ["fp16", "bf16"]: + dtype = {"fp16": ms.float16, "bf16": ms.bfloat16}[args.dtype] + ae = auto_mixed_precision( + ae, + args.amp_level, + dtype, + custom_fp32_cells=[nn.GroupNorm] if args.vae_keep_gn_fp32 else [], + ) + + # 4. build net with loss + ae_with_loss = GeneratorWithLoss( + ae, + kl_weight=args.kl_loss_weight, + perceptual_weight=args.perceptual_loss_weight, + use_real_rec_loss=args.use_real_rec_loss, + use_z_rec_loss=args.use_z_rec_loss, + use_image_identity_loss=args.use_image_identity_loss, + dtype=args.dtype, + ) + + tot_params, trainable_params = count_params(ae_with_loss) + logger.info("Total params {:,}; Trainable params {:,}".format(tot_params, trainable_params)) + + # 5. 
build training utils + # torch scale lr by: model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr + if args.scale_lr: + learning_rate = args.start_learning_rate * args.batch_size * args.gradient_accumulation_steps * device_num + logger.info(f"Learning rate is scaled to {learning_rate}") + else: + learning_rate = args.start_learning_rate + if not args.decay_steps: + args.decay_steps = max(1, args.epochs * dataset_size - args.warmup_steps) + + if args.scheduler != "constant": + assert ( + args.optim != "adamw_exp" + ), "For dynamic LR, mindspore.experimental.optim.AdamW needs to work with LRScheduler" + lr = create_scheduler( + steps_per_epoch=dataset_size, + name=args.scheduler, + lr=learning_rate, + end_lr=args.end_learning_rate, + warmup_steps=args.warmup_steps, + decay_steps=args.decay_steps, + num_epochs=args.epochs, + ) + else: + lr = learning_rate + + # build optimizer + update_logvar = False # in torch, ae_with_loss.logvar is not updated. + if update_logvar: + ae_params_to_update = [ae_with_loss.autoencoder.trainable_params(), ae_with_loss.logvar] + else: + ae_params_to_update = ae_with_loss.autoencoder.trainable_params() + optim_ae = create_optimizer( + ae_params_to_update, + name=args.optim, + betas=args.betas, + group_strategy=args.group_strategy, + weight_decay=args.weight_decay, + lr=lr, + ) + loss_scaler_ae = create_loss_scaler( + args.loss_scaler_type, args.init_loss_scale, args.loss_scale_factor, args.scale_window + ) + + ema = ( + EMA( + ae, + ema_decay=args.ema_decay, + offloading=False, + ) + if args.use_ema + else None + ) + + # resume training states + # TODO: resume Discriminator if used + ckpt_dir = os.path.join(args.output_path, "ckpt") + os.makedirs(ckpt_dir, exist_ok=True) + start_epoch = 0 + if args.resume: + resume_ckpt = os.path.join(ckpt_dir, "train_resume.ckpt") if isinstance(args.resume, bool) else args.resume + + start_epoch, loss_scale, cur_iter, last_overflow_iter = resume_train_network( + ae_with_loss, optim_ae, resume_ckpt + ) + loss_scaler_ae.loss_scale_value = loss_scale + loss_scaler_ae.cur_iter = cur_iter + loss_scaler_ae.last_overflow_iter = last_overflow_iter + logger.info(f"Resume training from {resume_ckpt}") + + # training step + training_step_ae = TrainOneStepWrapper( + ae_with_loss, + optimizer=optim_ae, + scale_sense=loss_scaler_ae, + drop_overflow_update=args.drop_overflow_update, + gradient_accumulation_steps=args.gradient_accumulation_steps, + clip_grad=args.clip_grad, + clip_norm=args.max_grad_norm, + ema=ema, + ) + + # support dynamic shape in graph mode + if args.mode == 0 and args.mixed_strategy == "mixed_video_random": + # (b c t h w), drop_remainder so bs fixed + # videos = ms.Tensor(shape=[args.batch_size, 3, None, image_size, image_size], dtype=ms.float32) + videos = ms.Tensor(shape=[None, 3, None, image_size, image_size], dtype=ms.float32) + training_step_ae.set_inputs(videos) + logger.info("Dynamic inputs are initialized for mixed_video_random training in Graph mode!") + + if rank_id == 0: + key_info = "Key Settings:\n" + "=" * 50 + "\n" + key_info += "\n".join( + [ + f"MindSpore mode[GRAPH(0)/PYNATIVE(1)]: {args.mode}", + f"Distributed mode: {args.use_parallel}", + f"amp level: {args.amp_level}", + f"dtype: {args.dtype}", + f"csv path: {args.csv_path}", + f"Video folder: {args.video_folder}", + f"Learning rate: {learning_rate}", + f"Batch size: {args.batch_size}", + f"Rescale size: {args.image_size}", + f"Weight decay: {args.weight_decay}", + f"Grad accumulation steps: {args.gradient_accumulation_steps}", + f"Num 
epochs: {args.epochs}", + f"Loss scaler: {args.loss_scaler_type}", + f"Init loss scale: {args.init_loss_scale}", + f"Grad clipping: {args.clip_grad}", + f"Max grad norm: {args.max_grad_norm}", + f"EMA: {args.use_ema}", + ] + ) + key_info += "\n" + "=" * 50 + logger.info(key_info) + + # 6. training process + use_flexible_train = False + if not use_flexible_train: + model = Model(training_step_ae) + + # callbacks + callback = [TimeMonitor(args.log_interval)] + ofm_cb = OverflowMonitor() + callback.append(ofm_cb) + + if rank_id == 0: + save_cb = EvalSaveCallback( + network=ae, + rank_id=rank_id, + ckpt_save_dir=ckpt_dir, + ema=ema, + ckpt_save_policy="latest_k", + ckpt_max_keep=args.ckpt_max_keep, + ckpt_save_interval=args.ckpt_save_interval, + log_interval=args.log_interval, + start_epoch=start_epoch, + model_name="vae_3d", + record_lr=False, + ) + callback.append(save_cb) + if args.profile: + callback.append(ProfilerCallback()) + + logger.info("Start training...") + # backup config files + shutil.copyfile(args.config, os.path.join(args.output_path, os.path.basename(args.config))) + + with open(os.path.join(args.output_path, "args.yaml"), "w") as f: + yaml.safe_dump(vars(args), stream=f, default_flow_style=False, sort_keys=False) + + model.train( + args.epochs, + dataloader, + callbacks=callback, + dataset_sink_mode=args.dataset_sink_mode, + # sink_size=args.sink_size, + initial_epoch=start_epoch, + ) + else: + if rank_id == 0: + ckpt_manager = CheckpointManager(ckpt_dir, "latest_k", k=args.ckpt_max_keep) + # output_numpy=True ? + ds_iter = dataloader.create_dict_iterator(args.epochs - start_epoch) + + for epoch in range(start_epoch, args.epochs): + start_time_e = time.time() + for step, data in enumerate(ds_iter): + start_time_s = time.time() + x = data["video"] + + global_step = epoch * dataset_size + step + global_step = ms.Tensor(global_step, dtype=ms.int64) + + # NOTE: inputs must match the order in GeneratorWithLoss.construct + loss_ae_t, overflow, scaling_sens = training_step_ae(x, global_step) + + cur_global_step = epoch * dataset_size + step + 1 # starting from 1 for logging + if overflow: + logger.warning(f"Overflow occurs in step {cur_global_step}") + + # log + step_time = time.time() - start_time_s + if step % args.log_interval == 0: + loss_ae = float(loss_ae_t.asnumpy()) + logger.info(f"E: {epoch+1}, S: {step+1}, Loss ae: {loss_ae:.4f}, Step time: {step_time*1000:.2f}ms") + + epoch_cost = time.time() - start_time_e + per_step_time = epoch_cost / dataset_size + cur_epoch = epoch + 1 + logger.info( + f"Epoch:[{int(cur_epoch):>3d}/{int(args.epochs):>3d}], " + f"epoch time:{epoch_cost:.2f}s, per step time:{per_step_time*1000:.2f}ms, " + ) + if rank_id == 0: + if (cur_epoch % args.ckpt_save_interval == 0) or (cur_epoch == args.epochs): + ckpt_name = f"vae_kl_f8-e{cur_epoch}.ckpt" + if ema is not None: + ema.swap_before_eval() + + ckpt_manager.save(ae, None, ckpt_name=ckpt_name, append_dict=None) + if ema is not None: + ema.swap_after_eval() + + # TODO: eval while training + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/opensora_hpcai/tools/annotate_vae_ucf101.py b/examples/opensora_hpcai/tools/annotate_vae_ucf101.py new file mode 100644 index 0000000000..b00ea352ba --- /dev/null +++ b/examples/opensora_hpcai/tools/annotate_vae_ucf101.py @@ -0,0 +1,38 @@ +import glob +import math +import os +import random + +root_dir = "UCF-101/" +train_ratio = 0.8 + +all_files = glob.glob(os.path.join(root_dir, "*/*.avi")) +num_samples = len(all_files) 
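+# NOTE: the glob above returns paths that still include `root_dir`; save_csv() below strips this prefix,
+# so the CSV files store clip paths relative to the dataset root (e.g. "ApplyEyeMakeup/v_ApplyEyeMakeup_g01_c01.avi"),
+# which are then joined with the `--video_folder` argument at training/inference time.
+# The shuffle below is unseeded, so the train/test split changes between runs; call
+# random.seed(<value>) beforehand if a reproducible split is needed.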
+print("Num samples: ", num_samples) + +# shuffle +# indices = list(range(num_samples)) +random.shuffle(all_files) + +# split +num_train = math.ceil(num_samples * train_ratio) +num_test = num_samples - num_train +train_set = sorted(all_files[:num_train]) +test_set = sorted(all_files[num_train:]) + + +# save csv +def save_csv(fns, save_path): + with open(save_path, "w") as fp: + fp.write("video\n") + for i, fn in enumerate(fns): + rel_path = fn.replace(root_dir, "") + if i != len(fns) - 1: + fp.write(f"{rel_path}\n") + else: + fp.write(f"{rel_path}") + + +save_csv(train_set, "ucf101_train.csv") +save_csv(test_set, "ucf101_test.csv") +print("Done. csv saved.") diff --git a/examples/opensora_hpcai/tools/convert_vae_3d.py b/examples/opensora_hpcai/tools/convert_vae_3d.py index 7974ca720f..3bc537f93f 100644 --- a/examples/opensora_hpcai/tools/convert_vae_3d.py +++ b/examples/opensora_hpcai/tools/convert_vae_3d.py @@ -14,14 +14,14 @@ def get_shape_from_str(shape): return shape -def convert(source_fp, target_fp, spatial_vae_only=False): +def convert(source_fp, target_fp, from_vae2d=False): # read param mapping files with open("tools/ms_pnames_vae1.2.txt") as file_ms: lines_ms = list(file_ms.readlines()) with open("tools/pt_pnames_vae1.2.txt") as file_pt: lines_pt = list(file_pt.readlines()) - if spatial_vae_only: + if from_vae2d: lines_ms = [line for line in lines_ms if line.startswith("spatial_vae")] lines_pt = [line for line in lines_pt if line.startswith("spatial_vae")] @@ -41,8 +41,13 @@ def convert(source_fp, target_fp, spatial_vae_only=False): shape_ms ), f"Mismatch param: PT: {name_pt}, {shape_pt} vs MS: {name_ms}, {shape_ms}" + if "from_vae2d": + name_pt = name_pt.replace("spatial_vae.module.", "") + data = sd_pt[name_pt].cpu().detach().numpy().reshape(shape_ms) - target_data.append({"name": name_ms, "data": ms.Tensor(data, dtype=ms.float32)}) + + data = ms.Tensor(input_data=data.astype(np.float32), dtype=ms.float32) + target_data.append({"name": name_ms, "data": data}) # ms.Tensor(data, dtype=ms.float32)}) print("Total params converted: ", len(target_data)) ms.save_checkpoint(target_data, target_fp) @@ -63,7 +68,7 @@ def convert(source_fp, target_fp, spatial_vae_only=False): type=str, help="Filename to save. 
Specify folder, e.g., ./models, or file path which ends with .ckpt, e.g., ./models/vae.ckpt", ) - parser.add_argument("--spatial_vae_only", action="store_true", help="only convert spatial vae, default: False") + parser.add_argument("--from_vae2d", action="store_true", help="only convert spatial vae, default: False") args = parser.parse_args() @@ -79,5 +84,5 @@ def convert(source_fp, target_fp, spatial_vae_only=False): if os.path.exists(target_fp): print(f"Warnings: {target_fp} will be overwritten!") - convert(args.src, target_fp, args.spatial_vae_only) + convert(args.src, target_fp, args.from_vae2d) print(f"Converted weight saved to {target_fp}") diff --git a/examples/opensora_hpcai/tools/plot.py b/examples/opensora_hpcai/tools/plot.py index d2b8b232aa..9f924c1abb 100644 --- a/examples/opensora_hpcai/tools/plot.py +++ b/examples/opensora_hpcai/tools/plot.py @@ -7,6 +7,9 @@ import argparse +import matplotlib + +matplotlib.use("Agg") import matplotlib.pyplot as plt import pandas as pd diff --git a/mindone/trainers/callback.py b/mindone/trainers/callback.py index f1d0d6f5ac..164fbd6e40 100755 --- a/mindone/trainers/callback.py +++ b/mindone/trainers/callback.py @@ -136,9 +136,10 @@ def choice_func(x): def on_train_step_end(self, run_context): cb_params = run_context.original_args() loss = _handle_loss(cb_params.net_outputs) - # cur_step = cb_params.cur_step_num + self.start_epoch * cb_params.batch_num opt = self._get_optimizer_from_cbp(cb_params) cur_step = int(opt.global_step.asnumpy().item()) + if cur_step <= 0: + cur_step = cb_params.cur_step_num + self.start_epoch * cb_params.batch_num step_num = (cb_params.batch_num * cb_params.epoch_num) if self.train_steps < 0 else self.train_steps
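
The `callback.py` change above guards against `opt.global_step` still being 0 when `on_train_step_end` fires (for example, when the first updates are dropped on loss-scale overflow, the optimizer step counter has not advanced yet). The snippet below is a minimal, self-contained sketch of that fallback logic; `resolve_cur_step` is a hypothetical helper used only for illustration and is not part of this patch.

```python
# Illustrative sketch (not part of the patch) of the cur_step fallback in on_train_step_end.
def resolve_cur_step(opt_global_step: int, cb_cur_step_num: int, start_epoch: int, batch_num: int) -> int:
    cur_step = int(opt_global_step)
    if cur_step <= 0:
        # the optimizer has not completed a successful update yet:
        # reconstruct the step index from the callback's own counters
        cur_step = cb_cur_step_num + start_epoch * batch_num
    return cur_step


print(resolve_cur_step(0, 1, 0, 100))   # 1  -> falls back to the callback counters
print(resolve_cur_step(37, 5, 2, 100))  # 37 -> trusts the optimizer's global_step
```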