diff --git a/config.py b/config.py
index b78e5a9b8..4c223875a 100644
--- a/config.py
+++ b/config.py
@@ -216,10 +216,12 @@ def create_parser():
                        help='Num of cycles for cosine decay and cyclic (default=1)')
     group.add_argument('--cycle_decay', type=float, default=1.0,
                        help='Decay rate of lr max in each cosine cycle (default=1.0)')
+    group.add_argument('--layer_decay', type=float, default=None,
+                       help='Layer-wise decay rate of lr (default=None)')
 
     # Loss parameters
     group = parser.add_argument_group('Loss parameters')
-    group.add_argument('--loss', type=str, default='CE', choices=['BCE', 'CE'],
+    group.add_argument('--loss', type=str, default='CE', choices=['BCE', 'CE', 'None'],
                        help='Type of loss, BCE (BinaryCrossEntropy) or CE (CrossEntropy) (default="CE")')
     group.add_argument('--label_smoothing', type=float, default=0.0,
                        help='Use label smoothing (default=0.0)')
@@ -270,6 +272,25 @@ def create_parser():
     group.add_argument('--train_url', type=str, default='/cache/output/',
                        help='model folder to save/load')
 
+    # pre-train
+    group = parser.add_argument_group('pre-train')
+    group.add_argument('--pretrain_resize', type=list, default=[224],
+                       help='Crop size(s) of the image for pre-training. '
+                            'The length of the list should be 2 if a tokenizer is required. (default=[224])')
+    group.add_argument('--pretrain_interpolations', type=list, default=['bicubic', 'bilinear'],
+                       help='Image interpolation modes for the resize operators in pre-training')
+    group.add_argument('--tokenizer', type=str, default=None,
+                       help='Name of the tokenizer model for pre-training')
+    group.add_argument('--tokenizer_ckpt_path', type=str, default='',
+                       help='Initialize the tokenizer model from this checkpoint')
+    group.add_argument('--mask_type', type=str, default='random',
+                       choices=['block_wise', 'patch_aligned', 'random'],
+                       help='Type of mask generator')
+    group.add_argument('--mask_ratio', type=float, default=0.75,
+                       help='Masking ratio')
+    group.add_argument('--mask_patch_size', type=int, default=32,
+                       help='Size of mask patch')
+
     return parser_config, parser
     # fmt: on
diff --git a/configs/mae/mae_b_16_224_finetune_ascend.yaml b/configs/mae/mae_b_16_224_finetune_ascend.yaml
new file mode 100644
index 000000000..7ddfcbcd9
--- /dev/null
+++ b/configs/mae/mae_b_16_224_finetune_ascend.yaml
@@ -0,0 +1,58 @@
+# system
+mode: 0
+distribute: True
+num_parallel_workers: 8
+val_while_train: True
+
+# dataset
+dataset: "imagenet"
+data_dir: "/path/to/imagenet"
+shuffle: True
+dataset_download: False
+batch_size: 32
+drop_remainder: True
+
+# augmentation
+image_resize: 224
+scale: [0.08, 1.0]
+ratio: [0.75, 1.333]
+hflip: 0.5
+interpolation: "bicubic"
+auto_augment: "randaug-m9-mstd0.5-inc1"
+re_prob: 0.25
+mixup: 0.8
+cutmix: 1.0
+re_value: "random"
+
+# model
+model: "mae_b_16_224_finetune"
+drop_rate: 0.0
+drop_path_rate: 0.1
+pretrained: False
+ckpt_path: ""
+keep_checkpoint_max: 10
+ckpt_save_dir: "./ckpt"
+epoch_size: 100
+dataset_sink_mode: True
+amp_level: "O2"
+
+# loss
+loss: "CE"
+loss_scale: 1024.0
+label_smoothing: 0.1
+
+# lr scheduler
+scheduler: "warmup_cosine_decay"
+lr: 5e-4
+min_lr: 1e-6
+warmup_epochs: 5
+warmup_factor: 0
+decay_epochs: 95
+layer_decay: 0.65
+lr_epoch_stair: False
+
+# optimizer
+opt: "adamw"
+weight_decay: 0.05
+filter_bias_and_bn: True
+use_nesterov: False
diff --git a/configs/mae/mae_b_16_224_pretrain_ascend.yaml b/configs/mae/mae_b_16_224_pretrain_ascend.yaml
new file mode 100644
index 000000000..601f64d60
--- /dev/null
+++ b/configs/mae/mae_b_16_224_pretrain_ascend.yaml
@@ -0,0 +1,57 @@
+# system
+mode: 0
+distribute: True +num_parallel_workers: 8 + +# dataset +dataset: "imagenet" +data_dir: "/path/to/imagenet" +shuffle: True +dataset_download: False +batch_size: 64 +drop_remainder: True + +# augmentation +scale: [0.2, 1.0] +ratio: [0.75, 1.333] +hflip: 0.5 +color_jitter: [0.4, 0.4, 0.4] + +# model +model: "mae_b_16_224_pretrain" +drop_rate: 0.0 +drop_path_rate: 0.0 +pretrained: False +ckpt_path: "" +keep_checkpoint_max: 10 +ckpt_save_dir: "./ckpt" +epoch_size: 800 +dataset_sink_mode: True +amp_level: "O2" +clip_grad: True +clip_value: 3.0 + +# loss +loss: "None" +loss_scale: 1024.0 + +# lr scheduler +scheduler: "warmup_cosine_decay" +lr: 1.5e-4 +min_lr: 0 +warmup_epochs: 40 +warmup_factor: 0 +decay_epochs: 760 +lr_epoch_stair: False + +# optimizer +opt: "adamw" +weight_decay: 0.05 +filter_bias_and_bn: True +use_nesterov: False + +# pre-train +pretrain_resize: [224] +pretrain_interpolations: ["bicubic"] +mask_type: "random" +mask_ratio: 0.75 diff --git a/examples/ssl/finetune.py b/examples/ssl/finetune.py new file mode 100644 index 000000000..bff2b5f72 --- /dev/null +++ b/examples/ssl/finetune.py @@ -0,0 +1,332 @@ +""" Model training pipeline """ +import logging +import os +import sys + +mindcv_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +sys.path.append(mindcv_path) + +import mindspore as ms +from mindspore import Tensor +from mindspore.communication import get_group_size, get_rank, init + +from mindcv.data import create_dataset, create_loader, create_transforms +from mindcv.loss import create_loss +from mindcv.models import create_model +from mindcv.optim import create_finetune_optimizer +from mindcv.scheduler import create_scheduler +from mindcv.utils import ( + AllReduceSum, + StateMonitor, + create_trainer, + get_metrics, + require_customized_train_step, + set_logger, + set_seed, +) + +from config import parse_args, save_args # isort: skip + +logger = logging.getLogger("mindcv.train") + + +def train(args): + """main train function""" + + ms.set_context(mode=args.mode) + if args.distribute: + init() + device_num = get_group_size() + rank_id = get_rank() + ms.set_auto_parallel_context( + device_num=device_num, + parallel_mode="data_parallel", + gradients_mean=True, + # we should but cannot set parameter_broadcast=True, which will cause error on gpu. 
+ ) + else: + device_num = None + rank_id = None + + set_seed(args.seed) + set_logger(name="mindcv", output_dir=args.ckpt_save_dir, rank=rank_id, color=False) + logger.info( + "We recommend installing `termcolor` via `pip install termcolor` " + "and setup logger by `set_logger(..., color=True)`" + ) + + # create dataset + dataset_train = create_dataset( + name=args.dataset, + root=args.data_dir, + split=args.train_split, + shuffle=args.shuffle, + num_samples=args.num_samples, + num_shards=device_num, + shard_id=rank_id, + num_parallel_workers=args.num_parallel_workers, + download=args.dataset_download, + num_aug_repeats=args.aug_repeats, + ) + + if args.num_classes is None: + num_classes = dataset_train.num_classes() + else: + num_classes = args.num_classes + + # create transforms + transform_list = create_transforms( + dataset_name=args.dataset, + is_training=True, + image_resize=args.image_resize, + scale=args.scale, + ratio=args.ratio, + hflip=args.hflip, + vflip=args.vflip, + color_jitter=args.color_jitter, + interpolation=args.interpolation, + auto_augment=args.auto_augment, + mean=args.mean, + std=args.std, + re_prob=args.re_prob, + re_scale=args.re_scale, + re_ratio=args.re_ratio, + re_value=args.re_value, + re_max_attempts=args.re_max_attempts, + ) + + # load dataset + loader_train = create_loader( + dataset=dataset_train, + batch_size=args.batch_size, + drop_remainder=args.drop_remainder, + is_training=True, + mixup=args.mixup, + cutmix=args.cutmix, + cutmix_prob=args.cutmix_prob, + num_classes=num_classes, + transform=transform_list, + num_parallel_workers=args.num_parallel_workers, + ) + + if args.val_while_train: + dataset_eval = create_dataset( + name=args.dataset, + root=args.data_dir, + split=args.val_split, + num_shards=device_num, + shard_id=rank_id, + num_parallel_workers=args.num_parallel_workers, + download=args.dataset_download, + ) + + transform_list_eval = create_transforms( + dataset_name=args.dataset, + is_training=False, + image_resize=args.image_resize, + crop_pct=args.crop_pct, + interpolation=args.interpolation, + mean=args.mean, + std=args.std, + ) + + loader_eval = create_loader( + dataset=dataset_eval, + batch_size=args.batch_size, + drop_remainder=False, + is_training=False, + transform=transform_list_eval, + num_parallel_workers=args.num_parallel_workers, + ) + # validation dataset count + eval_count = dataset_eval.get_dataset_size() + if args.distribute: + all_reduce = AllReduceSum() + eval_count = all_reduce(Tensor(eval_count, ms.int32)) + else: + loader_eval = None + eval_count = None + + num_batches = loader_train.get_dataset_size() + # Train dataset count + train_count = dataset_train.get_dataset_size() + if args.distribute: + all_reduce = AllReduceSum() + train_count = all_reduce(Tensor(train_count, ms.int32)) + + # create model + network = create_model( + model_name=args.model, + num_classes=num_classes, + in_channels=args.in_channels, + drop_rate=args.drop_rate, + drop_path_rate=args.drop_path_rate, + pretrained=args.pretrained, + checkpoint_path=args.ckpt_path, + ema=args.ema, + ) + + num_params = sum([param.size for param in network.get_parameters()]) + + # create loss + loss = create_loss( + name=args.loss, + reduction=args.reduction, + label_smoothing=args.label_smoothing, + aux_factor=args.aux_factor, + ) + + # create learning rate schedule + lr_scheduler = create_scheduler( + num_batches, + scheduler=args.scheduler, + lr=args.lr, + min_lr=args.min_lr, + warmup_epochs=args.warmup_epochs, + warmup_factor=args.warmup_factor, + 
decay_epochs=args.decay_epochs, + decay_rate=args.decay_rate, + milestones=args.multi_step_decay_milestones, + num_epochs=args.epoch_size, + num_cycles=args.num_cycles, + cycle_decay=args.cycle_decay, + lr_epoch_stair=args.lr_epoch_stair, + ) + + # resume training if ckpt_path is given + if args.ckpt_path != "" and args.resume_opt: + opt_ckpt_path = os.path.join(args.ckpt_save_dir, f"optim_{args.model}.ckpt") + else: + opt_ckpt_path = "" + + # create optimizer + # TODO: consistent naming opt, name, dataset_name + if ( + args.loss_scale_type == "fixed" + and args.drop_overflow_update is False + and not require_customized_train_step(args.ema, args.clip_grad, args.gradient_accumulation_steps) + ): + optimizer_loss_scale = args.loss_scale + else: + optimizer_loss_scale = 1.0 + optimizer = create_finetune_optimizer( + network, + opt=args.opt, + lr=lr_scheduler, + weight_decay=args.weight_decay, + momentum=args.momentum, + nesterov=args.use_nesterov, + filter_bias_and_bn=args.filter_bias_and_bn, + loss_scale=optimizer_loss_scale, + checkpoint_path=opt_ckpt_path, + eps=args.eps, + scale=args.layer_decay, + ) + + # Define eval metrics. + metrics = get_metrics(num_classes) + + # create trainer + trainer = create_trainer( + network, + loss, + optimizer, + metrics, + amp_level=args.amp_level, + amp_cast_list=args.amp_cast_list, + loss_scale_type=args.loss_scale_type, + loss_scale=args.loss_scale, + drop_overflow_update=args.drop_overflow_update, + ema=args.ema, + ema_decay=args.ema_decay, + clip_grad=args.clip_grad, + clip_value=args.clip_value, + gradient_accumulation_steps=args.gradient_accumulation_steps, + ) + + # callback + # save checkpoint, summary training loss + # record val acc and do model selection if val dataset is available + begin_step = 0 + begin_epoch = 0 + if args.ckpt_path != "": + begin_step = optimizer.global_step.asnumpy()[0] + begin_epoch = args.ckpt_path.split("/")[-1].split("-")[1].split("_")[0] + begin_epoch = int(begin_epoch) + + summary_dir = f"./{args.ckpt_save_dir}/summary" + assert ( + args.ckpt_save_policy != "top_k" or args.val_while_train is True + ), "ckpt_save_policy is top_k, val_while_train must be True." 
+ state_cb = StateMonitor( + trainer, + model_name=args.model, + model_ema=args.ema, + last_epoch=begin_epoch, + dataset_sink_mode=args.dataset_sink_mode, + dataset_val=loader_eval, + metric_name=list(metrics.keys()), + val_interval=args.val_interval, + ckpt_save_dir=args.ckpt_save_dir, + ckpt_save_interval=args.ckpt_save_interval, + ckpt_save_policy=args.ckpt_save_policy, + ckpt_keep_max=args.keep_checkpoint_max, + summary_dir=summary_dir, + log_interval=args.log_interval, + rank_id=rank_id, + device_num=device_num, + ) + + callbacks = [state_cb] + essential_cfg_msg = "\n".join( + [ + "Essential Experiment Configurations:", + f"MindSpore mode[GRAPH(0)/PYNATIVE(1)]: {args.mode}", + f"Distributed mode: {args.distribute}", + f"Number of devices: {device_num if device_num is not None else 1}", + f"Number of training samples: {train_count}", + f"Number of validation samples: {eval_count}", + f"Number of classes: {num_classes}", + f"Number of batches: {num_batches}", + f"Batch size: {args.batch_size}", + f"Auto augment: {args.auto_augment}", + f"MixUp: {args.mixup}", + f"CutMix: {args.cutmix}", + f"Model: {args.model}", + f"Model parameters: {num_params}", + f"Number of epochs: {args.epoch_size}", + f"Optimizer: {args.opt}", + f"Learning rate: {args.lr}", + f"LR Scheduler: {args.scheduler}", + f"Momentum: {args.momentum}", + f"Weight decay: {args.weight_decay}", + f"Auto mixed precision: {args.amp_level}", + f"Loss scale: {args.loss_scale}({args.loss_scale_type})", + ] + ) + logger.info(essential_cfg_msg) + save_args(args, os.path.join(args.ckpt_save_dir, f"{args.model}.yaml"), rank_id) + + if args.ckpt_path != "": + logger.info(f"Resume training from {args.ckpt_path}, last step: {begin_step}, last epoch: {begin_epoch}") + else: + logger.info("Start training") + + trainer.train(args.epoch_size, loader_train, callbacks=callbacks, dataset_sink_mode=args.dataset_sink_mode) + + +if __name__ == "__main__": + args = parse_args() + + # data sync for cloud platform if enabled + if args.enable_modelarts: + import moxing as mox + + args.data_dir = f"/cache/{args.data_url}" + mox.file.copy_parallel(src_url=os.path.join(args.data_url, args.dataset), dst_url=args.data_dir) + + # core training + train(args) + + if args.enable_modelarts: + mox.file.copy_parallel(src_url=args.ckpt_save_dir, dst_url=args.train_url) diff --git a/examples/ssl/pretrain.py b/examples/ssl/pretrain.py new file mode 100644 index 000000000..7eadcde73 --- /dev/null +++ b/examples/ssl/pretrain.py @@ -0,0 +1,281 @@ +""" Model pre-training pipeline """ +import logging +import os +import sys + +mindcv_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +sys.path.append(mindcv_path) + +import mindspore as ms +from mindspore import Tensor +from mindspore.communication import get_group_size, get_rank, init + +from mindcv.data import create_dataset, create_loader_pretrain, create_transforms_pretrain +from mindcv.loss import create_loss +from mindcv.models import create_model +from mindcv.optim import create_pretrain_optimizer +from mindcv.scheduler import create_scheduler +from mindcv.utils import AllReduceSum, StateMonitor, create_trainer, require_customized_train_step, set_logger, set_seed + +from config import parse_args, save_args # isort: skip + +logger = logging.getLogger("mindcv.pre-train") + + +def train(args): + """main train function""" + + ms.set_context(mode=args.mode) + if args.distribute: + init() + device_num = get_group_size() + rank_id = get_rank() + ms.set_auto_parallel_context( + device_num=device_num, 
+ parallel_mode="data_parallel", + gradients_mean=True, + # we should but cannot set parameter_broadcast=True, which will cause error on gpu. + ) + else: + device_num = None + rank_id = None + + set_seed(args.seed) + set_logger(name="mindcv", output_dir=args.ckpt_save_dir, rank=rank_id, color=False) + logger.info( + "We recommend installing `termcolor` via `pip install termcolor` " + "and setup logger by `set_logger(..., color=True)`" + ) + + # create dataset + dataset_train = create_dataset( + name=args.dataset, + root=args.data_dir, + split=args.train_split, + shuffle=args.shuffle, + num_samples=args.num_samples, + num_shards=device_num, + shard_id=rank_id, + num_parallel_workers=args.num_parallel_workers, + download=args.dataset_download, + num_aug_repeats=args.aug_repeats, + ) + + # create transforms + patch_size = int(args.model.split("_")[2]) # need to be more robust + transform_list = create_transforms_pretrain( + dataset_name=args.dataset, + resize_list=args.pretrain_resize, + tokenizer=args.tokenizer, + scale=args.scale, + ratio=args.ratio, + hflip=args.hflip, + color_jitter=args.color_jitter, + interpolations=args.pretrain_interpolations.copy(), + mean=args.mean, + std=args.std, + mask_type=args.mask_type, + mask_ratio=args.mask_ratio, + patch_size=patch_size, + mask_patch_size=args.mask_patch_size, + ) + + # load dataset + loader_train = create_loader_pretrain( + dataset=dataset_train, + batch_size=args.batch_size, + drop_remainder=args.drop_remainder, + transform=transform_list, + num_parallel_workers=args.num_parallel_workers, + ) + + loader_eval = None + + num_batches = loader_train.get_dataset_size() + # Train dataset count + train_count = dataset_train.get_dataset_size() + if args.distribute: + all_reduce = AllReduceSum() + train_count = all_reduce(Tensor(train_count, ms.int32)) + + # create model + network = create_model( + model_name=args.model, + drop_rate=args.drop_rate, + drop_path_rate=args.drop_path_rate, + mask_ratio=args.mask_ratio, + pretrained=args.pretrained, + checkpoint_path=args.ckpt_path, + ema=args.ema, + ) + + if args.tokenizer is not None: + tokenizer = create_model(model_name=args.tokenizer, checkpoint_path=args.tokenizer_ckpt_path) + else: + tokenizer = None + + num_params = sum([param.size for param in network.get_parameters()]) + + # create loss + if args.loss != "None": + loss = create_loss( + name=args.loss, + reduction=args.reduction, + label_smoothing=args.label_smoothing, + aux_factor=args.aux_factor, + ) + else: + loss = None + + # create learning rate schedule + lr_scheduler = create_scheduler( + num_batches, + scheduler=args.scheduler, + lr=args.lr, + min_lr=args.min_lr, + warmup_epochs=args.warmup_epochs, + warmup_factor=args.warmup_factor, + decay_epochs=args.decay_epochs, + decay_rate=args.decay_rate, + milestones=args.multi_step_decay_milestones, + num_epochs=args.epoch_size, + lr_epoch_stair=args.lr_epoch_stair, + num_cycles=args.num_cycles, + cycle_decay=args.cycle_decay, + ) + + # resume training if ckpt_path is given + if args.ckpt_path != "" and args.resume_opt: + opt_ckpt_path = os.path.join(args.ckpt_save_dir, f"optim_{args.model}.ckpt") + else: + opt_ckpt_path = "" + + # create optimizer + # TODO: consistent naming opt, name, dataset_name + if ( + args.loss_scale_type == "fixed" + and args.drop_overflow_update is False + and not require_customized_train_step(args.ema, args.clip_grad, args.gradient_accumulation_steps) + ): + optimizer_loss_scale = args.loss_scale + else: + optimizer_loss_scale = 1.0 + optimizer = 
create_pretrain_optimizer( + network, + opt=args.opt, + lr=lr_scheduler, + weight_decay=args.weight_decay, + momentum=args.momentum, + nesterov=args.use_nesterov, + filter_bias_and_bn=args.filter_bias_and_bn, + loss_scale=optimizer_loss_scale, + checkpoint_path=opt_ckpt_path, + eps=args.eps, + ) + + # Define eval metrics. + metrics = None + + # create trainer + trainer = create_trainer( + network, + loss, + optimizer, + metrics, + amp_level=args.amp_level, + amp_cast_list=args.amp_cast_list, + loss_scale_type=args.loss_scale_type, + loss_scale=args.loss_scale, + drop_overflow_update=args.drop_overflow_update, + ema=args.ema, + ema_decay=args.ema_decay, + clip_grad=args.clip_grad, + clip_value=args.clip_value, + gradient_accumulation_steps=args.gradient_accumulation_steps, + tokenizer=tokenizer, + ) + + # callback + # save checkpoint, summary training loss + # record val acc and do model selection if val dataset is available + begin_step = 0 + begin_epoch = 0 + if args.ckpt_path != "": + begin_step = optimizer.global_step.asnumpy()[0] + begin_epoch = args.ckpt_path.split("/")[-1].split("-")[1].split("_")[0] + begin_epoch = int(begin_epoch) + + summary_dir = f"./{args.ckpt_save_dir}/summary" + assert ( + args.ckpt_save_policy != "top_k" or args.val_while_train is True + ), "ckpt_save_policy is top_k, val_while_train must be True." + state_cb = StateMonitor( + trainer, + model_name=args.model, + model_ema=args.ema, + last_epoch=begin_epoch, + dataset_sink_mode=args.dataset_sink_mode, + dataset_val=loader_eval, + metric_name=[], + val_interval=args.val_interval, + ckpt_save_dir=args.ckpt_save_dir, + ckpt_save_interval=args.ckpt_save_interval, + ckpt_save_policy=args.ckpt_save_policy, + ckpt_keep_max=args.keep_checkpoint_max, + summary_dir=summary_dir, + log_interval=args.log_interval, + rank_id=rank_id, + device_num=device_num, + ) + + callbacks = [state_cb] + essential_cfg_msg = "\n".join( + [ + "Essential Experiment Configurations:", + f"MindSpore mode[GRAPH(0)/PYNATIVE(1)]: {args.mode}", + f"Distributed mode: {args.distribute}", + f"Number of devices: {device_num if device_num is not None else 1}", + f"Number of training samples: {train_count}", + f"Number of batches: {num_batches}", + f"Batch size: {args.batch_size}", + f"Auto augment: {args.auto_augment}", + f"MixUp: {args.mixup}", + f"CutMix: {args.cutmix}", + f"Model: {args.model}", + f"Model parameters: {num_params}", + f"Number of epochs: {args.epoch_size}", + f"Optimizer: {args.opt}", + f"Learning rate: {args.lr}", + f"LR Scheduler: {args.scheduler}", + f"Momentum: {args.momentum}", + f"Weight decay: {args.weight_decay}", + f"Auto mixed precision: {args.amp_level}", + f"Loss scale: {args.loss_scale}({args.loss_scale_type})", + ] + ) + logger.info(essential_cfg_msg) + save_args(args, os.path.join(args.ckpt_save_dir, f"{args.model}.yaml"), rank_id) + + if args.ckpt_path != "": + logger.info(f"Resume training from {args.ckpt_path}, last step: {begin_step}, last epoch: {begin_epoch}") + else: + logger.info("Start training") + + trainer.train(args.epoch_size, loader_train, callbacks=callbacks, dataset_sink_mode=args.dataset_sink_mode) + + +if __name__ == "__main__": + args = parse_args() + + # data sync for cloud platform if enabled + if args.enable_modelarts: + import moxing as mox + + args.data_dir = f"/cache/{args.data_url}" + mox.file.copy_parallel(src_url=os.path.join(args.data_url, args.dataset), dst_url=args.data_dir) + + # core training + train(args) + + if args.enable_modelarts: + 
mox.file.copy_parallel(src_url=args.ckpt_save_dir, dst_url=args.train_url) diff --git a/mindcv/data/__init__.py b/mindcv/data/__init__.py index 4ecf5e5d0..02a91024b 100644 --- a/mindcv/data/__init__.py +++ b/mindcv/data/__init__.py @@ -1,12 +1,21 @@ """ Data processing """ -from . import dataset_download, dataset_factory, loader, transforms_factory +from . import ( + dataset_download, + dataset_factory, + loader, + pretrain_loader, + pretrain_transforms_factory, + transforms_factory, +) from .auto_augment import * from .constants import * from .dataset_download import * from .dataset_factory import * from .loader import * +from .pretrain_loader import * +from .pretrain_transforms_factory import * from .transforms_factory import * __all__ = [] @@ -14,3 +23,5 @@ __all__.extend(dataset_factory.__all__) __all__.extend(loader.__all__) __all__.extend(transforms_factory.__all__) +__all__.extend(pretrain_loader.__all__) +__all__.extend(pretrain_transforms_factory.__all__) diff --git a/mindcv/data/mask_generator/__init__.py b/mindcv/data/mask_generator/__init__.py new file mode 100644 index 000000000..bf55764ba --- /dev/null +++ b/mindcv/data/mask_generator/__init__.py @@ -0,0 +1,5 @@ +from . import mask_factory +from .mask_factory import create_mask_generator + +__all__ = [] +__all__.extend(mask_factory.__all__) diff --git a/mindcv/data/mask_generator/block_wise_mask.py b/mindcv/data/mask_generator/block_wise_mask.py new file mode 100644 index 000000000..f6fc00cb3 --- /dev/null +++ b/mindcv/data/mask_generator/block_wise_mask.py @@ -0,0 +1,73 @@ +import math +import random +from typing import Optional, Tuple + +import numpy as np + + +class BlockWiseMaskGenerator: + def __init__( + self, + input_size: int = 224, + model_patch_size: int = 16, + mask_ratio: float = 0.4, + min_num_patches: int = 4, + max_num_patches: Optional[int] = None, + min_aspect: int = 0.3, + max_aspect: Optional[int] = None, + ): + assert input_size % model_patch_size == 0 + + grid_size = input_size // model_patch_size + self.height, self.width = (grid_size, grid_size) + + num_masking_patches = int(np.ceil(grid_size**2 * mask_ratio)) + self.num_masking_patches = num_masking_patches + + self.min_num_patches = min_num_patches + self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches + + max_aspect = max_aspect or 1 / min_aspect + self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) + + def _get_shape(self) -> Tuple[int, int]: + return self.height, self.width + + def _mask(self, mask: np.ndarray, max_mask_patches: int): + delta = 0 + for _ in range(10): + target_area = random.uniform(self.min_num_patches, max_mask_patches) + aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + if w < self.width and h < self.height: + top = random.randint(0, self.height - h) + left = random.randint(0, self.width - w) + + num_masked = mask[top : top + h, left : left + w].sum() + # Overlap + if 0 < h * w - num_masked <= max_mask_patches: + for i in range(top, top + h): + for j in range(left, left + w): + if mask[i, j] == 0: + mask[i, j] = 1 + delta += 1 + + if delta > 0: + break + return delta + + def __call__(self) -> np.ndarray: + mask = np.zeros(shape=self._get_shape(), dtype=np.int32) + mask_count = 0 + while mask_count < self.num_masking_patches: + max_mask_patches = self.num_masking_patches - mask_count + max_mask_patches = min(max_mask_patches, self.max_num_patches) + + 
+            delta = self._mask(mask, max_mask_patches)
+            if delta == 0:
+                break
+            else:
+                mask_count += delta
+
+        return mask
diff --git a/mindcv/data/mask_generator/mask_factory.py b/mindcv/data/mask_generator/mask_factory.py
new file mode 100644
index 000000000..de632e9d5
--- /dev/null
+++ b/mindcv/data/mask_generator/mask_factory.py
@@ -0,0 +1,20 @@
+from .block_wise_mask import BlockWiseMaskGenerator
+from .patch_aligned_mask import PatchAlignedMaskGenerator
+from .random_mask import RandomMaskGenerator
+
+__all__ = ["create_mask_generator"]
+
+
+def create_mask_generator(
+    mask_name: str, input_size: int = 224, patch_size: int = 16, mask_ratio: float = 0.6, **kwargs
+):
+    if mask_name == "random":
+        mask_generator = RandomMaskGenerator(input_size, patch_size, mask_ratio)
+    elif mask_name == "block_wise":
+        mask_generator = BlockWiseMaskGenerator(input_size, patch_size, mask_ratio)
+    elif mask_name == "patch_aligned":
+        mask_generator = PatchAlignedMaskGenerator(input_size, patch_size, mask_ratio, **kwargs)
+    else:
+        raise NotImplementedError(f"{mask_name} mask generator is not implemented.")
+
+    return mask_generator
diff --git a/mindcv/data/mask_generator/patch_aligned_mask.py b/mindcv/data/mask_generator/patch_aligned_mask.py
new file mode 100644
index 000000000..651a43da5
--- /dev/null
+++ b/mindcv/data/mask_generator/patch_aligned_mask.py
@@ -0,0 +1,25 @@
+import numpy as np
+
+
+class PatchAlignedMaskGenerator:
+    def __init__(
+        self, input_size: int = 192, model_patch_size: int = 4, mask_ratio: float = 0.6, mask_patch_size: int = 32
+    ):
+        assert input_size % mask_patch_size == 0
+        assert mask_patch_size % model_patch_size == 0
+
+        self.rand_size = input_size // mask_patch_size
+        self.scale = mask_patch_size // model_patch_size
+
+        self.token_count = self.rand_size**2
+        self.mask_count = int(np.ceil(self.token_count * mask_ratio))
+
+    def __call__(self):
+        mask_idx = np.random.permutation(self.token_count)[: self.mask_count]
+        mask = np.zeros(self.token_count, dtype=np.int32)
+        mask[mask_idx] = 1
+
+        mask = mask.reshape((self.rand_size, self.rand_size))
+        mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1)
+
+        return mask
diff --git a/mindcv/data/mask_generator/random_mask.py b/mindcv/data/mask_generator/random_mask.py
new file mode 100644
index 000000000..1077e54dc
--- /dev/null
+++ b/mindcv/data/mask_generator/random_mask.py
@@ -0,0 +1,18 @@
+import numpy as np
+
+
+class RandomMaskGenerator:
+    def __init__(self, input_size: int = 224, model_patch_size: int = 16, mask_ratio: float = 0.75):
+        assert input_size % model_patch_size == 0
+
+        self.grid_size = input_size // model_patch_size
+        self.seq_len = self.grid_size**2
+        self.mask_count = int(np.ceil(self.seq_len * mask_ratio))
+
+    def __call__(self):
+        mask_idx = np.random.permutation(self.seq_len)[: self.mask_count]
+        mask = np.zeros(self.seq_len, dtype=np.int32)
+        mask[mask_idx] = 1
+
+        mask = mask.reshape((self.grid_size, self.grid_size))
+        return mask
diff --git a/mindcv/data/pretrain_loader.py b/mindcv/data/pretrain_loader.py
new file mode 100644
index 000000000..2f1e4dd1d
--- /dev/null
+++ b/mindcv/data/pretrain_loader.py
@@ -0,0 +1,32 @@
+"""
+Create dataloader for pre-training
+"""
+import inspect
+
+__all__ = ["create_loader_pretrain"]
+
+
+def create_loader_pretrain(
+    dataset, batch_size, drop_remainder=False, transform=None, num_parallel_workers=None, python_multiprocessing=False
+):
+    if transform is None:
+        raise ValueError("transform should not be None for pre-training.")
+
+    # note: MindSpore 2.0 removed the parameter 'column_order'
+    sig = inspect.signature(dataset.map)
+    pass_column_order = False if "kwargs" in sig.parameters else True
+
+    dataset = dataset.map(
+        operations=transform,
+        input_columns="image",
+        output_columns=transform.output_columns,
+        column_order=transform.output_columns if pass_column_order else None,
+        num_parallel_workers=num_parallel_workers,
+        python_multiprocessing=python_multiprocessing,
+    )
+    if not pass_column_order:
+        dataset = dataset.project(transform.output_columns)
+
+    dataset = dataset.batch(batch_size=batch_size, drop_remainder=drop_remainder)
+
+    return dataset
diff --git a/mindcv/data/pretrain_transforms_factory.py b/mindcv/data/pretrain_transforms_factory.py
new file mode 100644
index 000000000..7129ad8e4
--- /dev/null
+++ b/mindcv/data/pretrain_transforms_factory.py
@@ -0,0 +1,127 @@
+"""
+Transform operation for pre-training
+"""
+
+from typing import List, Tuple, Union
+
+from mindspore.dataset import vision
+from mindspore.dataset.transforms import Compose
+from mindspore.dataset.vision import Inter
+
+from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .mask_generator import create_mask_generator
+
+__all__ = ["create_transforms_pretrain"]
+
+
+class RandomResizedCropWithTwoResolution:
+    def __init__(self, resize_list: List, interpolations: Union[List, Tuple], scale=(0.08, 1.0), ratio=(0.75, 1.333)):
+        self.first_transform = vision.RandomResizedCrop(resize_list[0], scale, ratio, interpolations[0])
+        self.second_transform = vision.RandomResizedCrop(resize_list[1], scale, ratio, interpolations[1])
+
+    def __call__(self, img):
+        return self.first_transform(img), self.second_transform(img)
+
+
+class TransformsForPretrain:
+    def __init__(
+        self,
+        resize_list: List = [224],
+        tokenizer: str = "dall_e",
+        mask_type: str = "block_wise",
+        scale=(0.08, 1.0),
+        ratio=(0.75, 1.333),
+        hflip=0.5,
+        color_jitter=None,
+        interpolations: Union[List, Tuple] = ["bicubic", "bilinear"],  # lanczos is not implemented in MindSpore
+        mean=IMAGENET_DEFAULT_MEAN,
+        std=IMAGENET_DEFAULT_STD,
+        patch_size: int = 16,
+        mask_ratio: float = 0.4,
+        **kwargs
+    ):
+        for i in range(len(interpolations)):
+            if hasattr(Inter, interpolations[i].upper()):
+                interpolations[i] = getattr(Inter, interpolations[i].upper())
+            else:
+                interpolations[i] = Inter.BILINEAR
+
+        if len(resize_list) == 2:
+            common_transform = [vision.Decode()]
+            if color_jitter is not None:
+                if isinstance(color_jitter, (list, tuple)):
+                    # color jitter should be a 3-tuple/list for brightness/contrast/saturation
+                    # or 4 if also augmenting hue
+                    assert len(color_jitter) in (3, 4)
+                else:
+                    color_jitter = (float(color_jitter),) * 3
+                common_transform += [vision.RandomColorAdjust(*color_jitter)]
+
+            if hflip > 0.0:
+                common_transform += [vision.RandomHorizontalFlip(prob=hflip)]
+
+            common_transform += [RandomResizedCropWithTwoResolution(resize_list, interpolations, scale, ratio)]
+            self.common_transform = Compose(common_transform)
+
+            self.patch_transform = Compose([vision.Normalize(mean=mean, std=std), vision.HWC2CHW()])
+
+            if tokenizer == "dall_e":  # beit
+                self.visual_token_transform = Compose([vision.ToTensor(), lambda x: (1 - 2 * 0.1) * x + 0.1])
+            elif tokenizer == "vqkd":  # beit v2
+                self.visual_token_transform = Compose([vision.ToTensor()])
+            elif tokenizer == "clip":  # eva, eva-02
+                self.visual_token_transform = Compose(
+                    [
+                        vision.ToTensor(),
+                        vision.Normalize(
+                            mean=(0.48145466, 0.4578275, 0.40821073),
+                            std=(0.26862954, 0.26130258, 0.27577711),
+                            is_hwc=False,
+                        ),
+                    ]
+                )
+
self.masked_position_generator = create_mask_generator( + mask_type, input_size=resize_list[0], patch_size=patch_size, mask_ratio=mask_ratio, **kwargs + ) + + self.output_columns = ["patch", "token", "mask"] + else: + self.common_transform = None + + patch_transform = [ + vision.RandomCropDecodeResize( + size=resize_list[0], scale=scale, ratio=ratio, interpolation=interpolations[0] + ) + ] + + if hflip > 0.0: + patch_transform += [vision.RandomHorizontalFlip(hflip)] + + patch_transform += [vision.Normalize(mean=mean, std=std), vision.HWC2CHW()] + self.patch_transform = Compose(patch_transform) + + self.masked_position_generator = create_mask_generator( + mask_type, input_size=resize_list[0], patch_size=patch_size, mask_ratio=mask_ratio, **kwargs + ) + + self.output_columns = ["patch", "mask"] + + def __call__(self, image): + if self.common_transform is not None: # for beit, beit v2, eva, eva-02 + patches, visual_tokens = self.common_transform(image) + patches = self.patch_transform(patches) + visual_tokens = self.visual_token_transform(visual_tokens) + masks = self.masked_position_generator() + return patches, visual_tokens, masks + else: + patches = self.patch_transform(image) # for MAE, SimMIM + masks = self.masked_position_generator() + return patches, masks + + +def create_transforms_pretrain(dataset_name="", **kwargs): + if dataset_name in ("imagenet", ""): + return TransformsForPretrain(**kwargs) + else: + raise NotImplementedError() diff --git a/mindcv/models/__init__.py b/mindcv/models/__init__.py index 75fa2f6db..50d8ab979 100644 --- a/mindcv/models/__init__.py +++ b/mindcv/models/__init__.py @@ -17,6 +17,7 @@ inceptionv3, inceptionv4, layers, + mae, mixnet, mlpmixer, mnasnet, @@ -72,6 +73,7 @@ from .inceptionv3 import * from .inceptionv4 import * from .layers import * +from .mae import * from .mixnet import * from .mlpmixer import * from .mnasnet import * @@ -129,6 +131,7 @@ __all__.extend(["InceptionV3", "inception_v3"]) __all__.extend(["InceptionV4", "inception_v4"]) __all__.extend(layers.__all__) +__all__.extend(mae.__all__) __all__.extend(mixnet.__all__) __all__.extend(mlpmixer.__all__) __all__.extend(mnasnet.__all__) diff --git a/mindcv/models/helpers.py b/mindcv/models/helpers.py index c7e02a033..24b68513e 100644 --- a/mindcv/models/helpers.py +++ b/mindcv/models/helpers.py @@ -9,8 +9,11 @@ from itertools import repeat from typing import Callable, Dict, List, Optional +import numpy as np +from scipy import interpolate + import mindspore.nn as nn -from mindspore import load_checkpoint, load_param_into_net +from mindspore import Parameter, Tensor, load_checkpoint, load_param_into_net, ops from ..utils.download import DownLoad, get_default_download_root from .features import FeatureExtractWrapper @@ -71,7 +74,8 @@ def load_pretrained(model, default_cfg, num_classes=1000, in_channels=3, filter_ if filter_fn is not None: param_dict = filter_fn(param_dict) - load_param_into_net(model, param_dict) + strict_load = default_cfg.get("strict_load", False) + load_param_into_net(model, param_dict, strict_load) def make_divisible( @@ -198,3 +202,105 @@ def build_model_with_cfg( raise RuntimeError(f"`feature_only` is not implemented for `{model_cls.__name__}` model.") from e return model + + +def interpolate_relative_position_bias(checkpoint_params, network): + if "rel_pos_bias.relative_position_bias_table" in checkpoint_params \ + and isinstance(network.rel_pos_bias, nn.CellList): + + num_layers = network.get_num_layers() + rel_pos_bias = 
checkpoint_params["rel_pos_bias.relative_position_bias_table"] + for i in range(num_layers): + checkpoint_params[f"rel_pos_bias.{i}.relative_position_bias_table"] = rel_pos_bias.clone() + checkpoint_params.pop("rel_pos_bias.relative_position_bias_table") + + elif "rel_pos_bias.0.relative_position_bias_table" in checkpoint_params \ + and not isinstance(network.rel_pos_bias, nn.CellList) \ + and isinstance(network.rel_pos_bias, nn.Cell): + + raise NotImplementedError("Converting multiple relative position bias to one is not supported.") + + all_keys = list(checkpoint_params.keys()) + for key in all_keys: + if "relative_position_index" in key: + checkpoint_params.pop(key) + + if "relative_position_bias_table" in key: + bias_table = checkpoint_params[key] + src_num_pos, num_attn_heads = bias_table.shape + dst_num_pos, _ = network.parameters_dict()[key].shape + dst_patch_shape = network.patch_embed.patches_resolution + if dst_patch_shape[0] != dst_patch_shape[1]: + raise NotImplementedError("Unsquared patch is not supported.") + + num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1) + src_size = int((src_num_pos - num_extra_tokens) ** 0.5) + dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5) + if src_size != dst_size: + print("Position interpolate for %s from %dx%d to %dx%d" % (key, src_size, src_size, dst_size, dst_size)) + extra_tokens = bias_table[-num_extra_tokens:, :] + rel_pos_bias = bias_table[:-num_extra_tokens, :] + + def geometric_progression(a, r, n): + return a * (1.0 - r ** n) / (1.0 - r) + + left, right = 1.01, 1.5 + while right - left > 1e-6: + q = (left + right) / 2.0 + gp = geometric_progression(1, q, src_size // 2) + if gp > dst_size // 2: + right = q + else: + left = q + + dis = [] + cur = 1 + for i in range(src_size // 2): + dis.append(cur) + cur += q ** (i + 1) + + r_ids = [-_ for _ in reversed(dis)] + x = r_ids + [0] + dis + y = r_ids + [0] + dis + + t = dst_size // 2.0 + dx = np.arange(-t, t + 0.1, 1.0) + dy = np.arange(-t, t + 0.1, 1.0) + + all_rel_pos_bias = [] + for i in range(num_attn_heads): + z = ops.reshape(rel_pos_bias[:, i], (src_size, src_size)).asnumpy() + f = interpolate.interp2d(x, y, z, kind="cubic") + all_rel_pos_bias.append(ops.reshape(Tensor(f(dx, dy), dtype=bias_table.dtype), (-1, 1))) + + rel_pos_bias = ops.concat(all_rel_pos_bias, axis=-1) + new_rel_pos_bias = ops.concat((rel_pos_bias, extra_tokens), axis=0) + checkpoint_params[key] = Parameter(new_rel_pos_bias) + + return checkpoint_params + + +def interpolate_pos_embed(checkpoint_params, network): + pos_embed_checkpoint = checkpoint_params["pos_embed"] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = network.patch_embed.num_patches + num_extra_tokens = network.pos_embed.shape[-2] - num_patches + + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + new_size = int(num_patches ** 0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = ops.reshape(pos_tokens, (-1, orig_size, orig_size, embedding_size)) + pos_tokens = ops.transpose(pos_tokens, (0, 3, 1, 2)) + pos_tokens = ops.interpolate(pos_tokens, size=(new_size, new_size), + mode='bicubic', align_corners=False) # require MindSpore 2.0 + pos_tokens = 
ops.transpose(pos_tokens, (0, 2, 3, 1)) + pos_tokens = ops.reshape(pos_tokens, (-1, new_size * new_size, embedding_size)) + new_pos_embed = ops.concat((extra_tokens, pos_tokens), axis=1) + checkpoint_params['pos_embed'] = Parameter(new_pos_embed) + + return checkpoint_params diff --git a/mindcv/models/mae.py b/mindcv/models/mae.py new file mode 100644 index 000000000..18b1c5233 --- /dev/null +++ b/mindcv/models/mae.py @@ -0,0 +1,431 @@ +from functools import partial +from typing import Optional + +import numpy as np + +import mindspore as ms +from mindspore import Parameter, Tensor, nn, ops +from mindspore.common.initializer import Normal, initializer + +from .helpers import load_pretrained +from .registry import register_model +from .vit_encoder import Block, VisionTransformerEncoder + +__all__ = [ + "mae_b_16_224_pretrain", + "mae_l_16_224_pretrain", + "mae_h_16_224_pretrain", + "mae_b_16_224_finetune", + "mae_l_16_224_finetune", + "mae_h_14_224_finetune" +] + + +def _cfg(url="", strict_load=False, **kwargs): + return { + "url": url, + "num_classes": 1000, + "input_size": (3, 224, 224), + "first_conv": "patch_embed.proj", + "classifier": "head", + "strict_load": strict_load, + **kwargs, + } + + +default_cfgs = { + "mae_b_16_224_finetune": _cfg(url="", strict_load=True), + "mae_l_16_224_finetune": _cfg(url=""), + "mae_h_14_224_finetune": _cfg(url=""), +} + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2. + omega = 1. 
/ 10000 ** omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +class MAEForPretrain(VisionTransformerEncoder): + def __init__( + self, + img_size: int = 224, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 1024, + depth: int = 24, + num_heads: int = 16, + mlp_ratio: float = 4., + decoder_embed_dim: int = 512, + decoder_depth: int = 8, + decoder_num_heads: int = 16, + act_layer: nn.Cell = nn.GELU, + norm_layer: nn.Cell = nn.LayerNorm, + norm_pix_loss: bool = True, + mask_ratio: float = 0.75, + **kwargs + ): + super(MAEForPretrain, self).__init__( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + depth=depth, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + init_values=None, + act_layer=act_layer, + norm_layer=norm_layer, + use_abs_pos_emb=True, + use_rel_pos_bias=False, + use_shared_rel_pos_bias=False, + **kwargs + ) + self.cls_token = Parameter(initializer(Normal(sigma=0.02), (1, 1, embed_dim))) + + self.unmask_len = int(np.floor(self.num_patches * (1 - mask_ratio))) + + encoder_pos_emb = Tensor(get_2d_sincos_pos_embed( + embed_dim, int(self.num_patches ** 0.5), cls_token=True), ms.float32 + ) + encoder_pos_emb = ops.expand_dims(encoder_pos_emb, axis=0) + self.pos_embed = Parameter(encoder_pos_emb, requires_grad=False) + self.norm = norm_layer((embed_dim,)) + + self.decoder_embed = nn.Dense(embed_dim, decoder_embed_dim) + self.mask_token = Parameter(initializer(Normal(sigma=0.02), (1, 1, decoder_embed_dim))) + + decoder_pos_emb = Tensor(get_2d_sincos_pos_embed( + decoder_embed_dim, int(self.num_patches ** 0.5), cls_token=True), ms.float32 + ) + decoder_pos_emb = ops.expand_dims(decoder_pos_emb, axis=0) + self.decoder_pos_embed = Parameter(decoder_pos_emb, requires_grad=False) + + self.decoder_blocks = nn.CellList([ + Block( + dim=decoder_embed_dim, num_heads=decoder_num_heads, qkv_bias=True, + mlp_ratio=mlp_ratio, init_values=None, act_layer=act_layer, norm_layer=norm_layer, + ) for _ in range(decoder_depth) + ]) + self.decoder_norm = norm_layer((decoder_embed_dim,)) + self.decoder_head = nn.Dense(decoder_embed_dim, patch_size ** 2 * in_chans) + + self.sort = ops.Sort() + + self.norm_pix_loss = norm_pix_loss + self._init_weights() + + def _init_weights(self): + for name, cell in self.cells_and_names(): + if isinstance(cell, nn.Dense): + cell.weight.set_data( + initializer("xavier_uniform", cell.weight.shape, cell.weight.dtype) + ) + if cell.bias is not None: + cell.bias.set_data( + initializer('zeros', cell.bias.shape, cell.bias.dtype) + ) + + elif isinstance(cell, nn.LayerNorm): + cell.gamma.set_data( + initializer('ones', cell.gamma.shape, cell.gamma.dtype) + ) + cell.beta.set_data( + initializer('zeros', cell.beta.shape, cell.beta.dtype) + ) + + if name == "patch_embed.proj": + cell.weight.set_data( + initializer("xavier_uniform", cell.weight.shape, cell.weight.dtype) + ) + + def patchify(self, imgs): + """ + imgs: (N, 3, H, W) + x: (N, L, patch_size ** 2 * 3) + """ + N, _, H, W = imgs.shape + p = self.patch_embed.patch_size[0] + assert H == W and H % p == 0 + h = w = H // p + + x = ops.reshape(imgs, (N, 3, h, p, w, p)) + x = ops.transpose(x, (0, 2, 4, 3, 5, 1)) + x = ops.reshape(x, (N, h * w, p ** 2 * 3)) + return x + + def unpatchify(self, x): + """ + x: (N, L, patch_size ** 2 * 3) + imgs: (N, 3, H, W) + """ + N, L, _ 
= x.shape + p = self.patch_embed.patch_size[0] + h = w = int(L ** 0.5) + assert h * w == L + + imgs = ops.reshape(x, (N, h, w, p, p, 3)) + imgs = ops.transpose(imgs, (0, 5, 1, 3, 2, 4)) + imgs = ops.reshape(imgs, (N, 3, h * p, w * p)) + return imgs + + def apply_masking(self, x, mask): + D = x.shape[2] + _, ids_shuffle = self.sort(mask.astype(ms.float32)) + _, ids_restore = self.sort(ids_shuffle.astype(ms.float32)) + + ids_keep = ids_shuffle[:, :self.unmask_len] + ids_keep = ops.broadcast_to(ops.expand_dims(ids_keep, axis=-1), (-1, -1, D)) + x_unmasked = ops.gather_elements(x, dim=1, index=ids_keep) + + return x_unmasked, ids_restore + + def forward_features(self, x, mask): + x = self.patch_embed(x) + bsz = x.shape[0] + + x = x + self.pos_embed[:, 1:, :] + x, ids_restore = self.apply_masking(x, mask) + + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_token = ops.broadcast_to(cls_token, (bsz, -1, -1)) + cls_token = cls_token.astype(x.dtype) + x = ops.concat((cls_token, x), axis=1) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + return x, ids_restore + + def forward_decoder(self, x, ids_restore): + x = self.decoder_embed(x) + bsz, L, D = x.shape + + mask_len = self.num_patches + 1 - L + mask_tokens = ops.broadcast_to(self.mask_token, (bsz, mask_len, -1)) + mask_tokens = mask_tokens.astype(x.dtype) + + x_ = ops.concat((x[:, 1:, :], mask_tokens), axis=1) + ids_restore = ops.broadcast_to(ops.expand_dims(ids_restore, axis=-1), (-1, -1, D)) + x_ = ops.gather_elements(x_, dim=1, index=ids_restore) + x = ops.concat((x[:, :1, :], x_), axis=1) + + x = x + self.decoder_pos_embed + + for blk in self.decoder_blocks: + x = blk(x) + + x = self.decoder_norm(x) + x = self.decoder_head(x) + + return x[:, 1:, :] + + def forward_loss(self, imgs, pred, mask): + target = self.patchify(imgs) + if self.norm_pix_loss: + mean = target.mean(axis=-1, keep_dims=True) + std = target.std(axis=-1, keepdims=True) + target = (target - mean) / std + + loss = (pred - target) ** 2 + loss = loss.mean(axis=-1) + + mask = mask.astype(loss.dtype) + loss = (loss * mask).sum() / mask.sum() + return loss + + def construct(self, imgs, mask): + bsz = imgs.shape[0] + mask = ops.reshape(mask, (bsz, -1)) + features, ids_restore = self.forward_features(imgs, mask) + pred = self.forward_decoder(features, ids_restore) + loss = self.forward_loss(imgs, pred, mask) + return loss + + +class MAEForFinetune(VisionTransformerEncoder): + def __init__( + self, + img_size: int = 224, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + attn_head_dim: Optional[int] = None, + mlp_ratio: float = 4., + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + pos_drop_rate: float = 0., + proj_drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + act_layer: nn.Cell = nn.GELU, + norm_layer: nn.Cell = nn.LayerNorm, + num_classes: int = 1000, + use_mean_pooling: bool = True, + **kwargs + ): + super(MAEForFinetune, self).__init__( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + depth=depth, + num_heads=num_heads, + attn_head_dim=attn_head_dim, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + pos_drop_rate=pos_drop_rate, + proj_drop_rate=proj_drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + init_values=None, + act_layer=act_layer, + norm_layer=norm_layer, + use_abs_pos_emb=True, + use_rel_pos_bias=False, + use_shared_rel_pos_bias=False, + **kwargs + 
) + self.use_mean_pooling = use_mean_pooling + if self.use_mean_pooling: + self.fc_norm = norm_layer((embed_dim,)) + else: + self.norm = norm_layer((embed_dim,)) + self.head = nn.Dense(embed_dim, num_classes, weight_init='TruncatedNormal') + + self._init_weights() + self._fix_init_weights() + + def construct(self, x): + x = self.forward_features(x) + if self.use_mean_pooling: + x = x[:, 1:].mean(axis=1) + x = self.fc_norm(x) + else: + x = self.norm(x) + x = x[:, 0] + x = self.head(x) + return x + + +@register_model +def mae_b_16_224_pretrain(pretrained=False, **kwargs): + model = MAEForPretrain( + patch_size=16, embed_dim=768, depth=12, num_heads=12, + decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, + act_layer=partial(nn.GELU, approximate=False), + norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs + ) + if pretrained: + pass + return model + + +@register_model +def mae_l_16_224_pretrain(pretrained=False, **kwargs): + model = MAEForPretrain( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, + decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, + act_layer=partial(nn.GELU, approximate=False), + norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs + ) + if pretrained: + pass + return model + + +@register_model +def mae_h_16_224_pretrain(pretrained=False, **kwargs): + model = MAEForPretrain( + patch_size=16, embed_dim=1280, depth=32, num_heads=16, + decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, + act_layer=partial(nn.GELU, approximate=False), + norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs + ) + if pretrained: + pass + return model + + +@register_model +def mae_b_16_224_finetune(pretrained=True, in_chans=3, num_classes=1000, **kwargs): + default_cfg = default_cfgs["mae_b_16_224_finetune"] + model = MAEForFinetune( + patch_size=16, in_chans=in_chans, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, epsilon=1e-6), num_classes=num_classes, **kwargs + ) + if pretrained: + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_chans) + return model + + +@register_model +def mae_l_16_224_finetune(pretrained=True, in_chans=3, num_classes=1000, **kwargs): + default_cfg = default_cfgs["mae_l_16_224_finetune"] + model = MAEForFinetune( + patch_size=16, in_chans=in_chans, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, epsilon=1e-6), num_classes=num_classes, **kwargs + ) + if pretrained: + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_chans) + return model + + +@register_model +def mae_h_14_224_finetune(pretrained=True, in_chans=3, num_classes=1000, **kwargs): + default_cfg = default_cfgs["mae_h_14_224_finetune"] + model = MAEForFinetune( + patch_size=14, in_chans=in_chans, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, epsilon=1e-6), num_classes=num_classes, **kwargs + ) + if pretrained: + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_chans) + return model diff --git a/mindcv/models/vit_encoder.py b/mindcv/models/vit_encoder.py new file mode 100644 index 000000000..010065273 --- /dev/null +++ b/mindcv/models/vit_encoder.py @@ -0,0 +1,296 @@ +import math +from typing import Optional, Tuple + +import numpy as np + +import mindspore as ms +from mindspore import Parameter, Tensor, nn, ops +from mindspore.common.initializer import TruncatedNormal, initializer + +from .layers.drop_path import 
DropPath +from .layers.mlp import Mlp +from .layers.patch_embed import PatchEmbed + + +class RelativePositionBiasWithCLS(nn.Cell): + def __init__( + self, + window_size: Tuple[int], + num_heads: int + ): + super(RelativePositionBiasWithCLS, self).__init__() + self.window_size = window_size + self.num_tokens = window_size[0] * window_size[1] + + num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + # 3: cls to token, token to cls, cls to cls + self.relative_position_bias_table = Parameter( + Tensor(np.zeros((num_relative_distance, num_heads)), dtype=ms.float16) + ) + coords_h = np.arange(window_size[0]).reshape(window_size[0], 1).repeat(window_size[1], 1).reshape(1, -1) + coords_w = np.arange(window_size[1]).reshape(1, window_size[1]).repeat(window_size[0], 0).reshape(1, -1) + coords_flatten = np.concatenate([coords_h, coords_w], axis=0) # [2, Wh * Ww] + + relative_coords = coords_flatten[:, :, np.newaxis] - coords_flatten[:, np.newaxis, :] # [2, Wh * Ww, Wh * Ww] + relative_coords = relative_coords.transpose(1, 2, 0) # [Wh * Ww, Wh * Ww, 2] + relative_coords[:, :, 0] += window_size[0] - 1 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[0] - 1 + + relative_position_index = np.zeros((self.num_tokens + 1, self.num_tokens + 1), + dtype=relative_coords.dtype) # [Wh * Ww + 1, Wh * Ww + 1] + relative_position_index[1:, 1:] = relative_coords.sum(-1) + relative_position_index[0, 0:] = num_relative_distance - 3 + relative_position_index[0:, 0] = num_relative_distance - 2 + relative_position_index[0, 0] = num_relative_distance - 1 + relative_position_index = Tensor(relative_position_index.reshape(-1)) + + self.one_hot = nn.OneHot(axis=-1, depth=num_relative_distance, dtype=ms.float16) + self.relative_position_index = Parameter(self.one_hot(relative_position_index), requires_grad=False) + + def construct(self): + out = ops.matmul(self.relative_position_index, self.relative_position_bias_table) + out = ops.reshape(out, (self.num_tokens + 1, self.num_tokens + 1, -1)) + out = ops.transpose(out, (2, 0, 1)) + out = ops.expand_dims(out, 0) + return out + + +class Attention(nn.Cell): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_scale: Optional[float] = None, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + attn_head_dim: Optional[int] = None, + ): + super(Attention, self).__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * num_heads + + if qk_scale: + self.scale = Tensor(qk_scale) + else: + self.scale = Tensor(head_dim ** -0.5) + + self.qkv = nn.Dense(dim, all_head_dim * 3, has_bias=qkv_bias) + + self.attn_drop = nn.Dropout(1 - attn_drop) + self.proj = nn.Dense(all_head_dim, dim) + self.proj_drop = nn.Dropout(1 - proj_drop) + + self.mul = ops.Mul() + self.reshape = ops.Reshape() + self.transpose = ops.Transpose() + self.unstack = ops.Unstack(axis=0) + self.attn_matmul_v = ops.BatchMatMul() + self.q_matmul_k = ops.BatchMatMul(transpose_b=True) + + def construct(self, x, rel_pos_bias=None): + b, n, c = x.shape + qkv = self.qkv(x) + qkv = self.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads)) + qkv = self.transpose(qkv, (2, 0, 3, 1, 4)) + q, k, v = self.unstack(qkv) + + attn = self.q_matmul_k(q, k) + attn = self.mul(attn, self.scale) + + if rel_pos_bias is not None: + attn = attn + rel_pos_bias + + attn = attn.astype(ms.float32) + attn = ops.softmax(attn, axis=-1) + attn = 
self.attn_drop(attn) + + out = self.attn_matmul_v(attn, v) + out = self.transpose(out, (0, 2, 1, 3)) + out = self.reshape(out, (b, n, c)) + out = self.proj(out) + out = self.proj_drop(out) + + return out + + +class LayerScale(nn.Cell): + def __init__( + self, + dim: int, + init_values: float = 1e-5 + ): + super(LayerScale, self).__init__() + self.gamma = Parameter(initializer(init_values, dim)) + + def construct(self, x): + return self.gamma * x + + +class Block(nn.Cell): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_scale: Optional[float] = None, + attn_drop: float = 0., + proj_drop: float = 0., + attn_head_dim: Optional[int] = None, + mlp_ratio: float = 4., + drop_path: float = 0., + init_values: Optional[float] = None, + act_layer: nn.Cell = nn.GELU, + norm_layer: nn.Cell = nn.LayerNorm, + ): + super(Block, self).__init__() + self.norm1 = norm_layer((dim,)) + self.attn = Attention( + dim=dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=proj_drop, attn_head_dim=attn_head_dim, + ) + self.ls1 = LayerScale(dim=dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer((dim,)) + self.mlp = Mlp( + in_features=dim, hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, drop=proj_drop + ) + self.ls2 = LayerScale(dim=dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def construct(self, x, rel_pos_bias=None): + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), rel_pos_bias))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class VisionTransformerEncoder(nn.Cell): + def __init__( + self, + img_size: int = 224, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + attn_head_dim: Optional[int] = None, + mlp_ratio: float = 4., + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + pos_drop_rate: float = 0., + proj_drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + init_values: Optional[float] = 0.1, + act_layer: nn.Cell = nn.GELU, + norm_layer: nn.Cell = nn.LayerNorm, + use_abs_pos_emb: bool = False, + use_rel_pos_bias: bool = False, + use_shared_rel_pos_bias: bool = True, + **kwargs + ): + super(VisionTransformerEncoder, self).__init__() + self.embed_dim = embed_dim + self.patch_embed = PatchEmbed(image_size=img_size, patch_size=patch_size, + in_chans=in_chans, embed_dim=embed_dim) + self.num_patches = self.patch_embed.num_patches + + self.cls_token = Parameter(initializer(TruncatedNormal(0.02), (1, 1, embed_dim))) + + self.pos_embed = Parameter( + initializer(TruncatedNormal(0.02), (1, self.num_patches + 1, embed_dim))) if use_abs_pos_emb else None + self.pos_drop = nn.Dropout(1 - pos_drop_rate) + + if use_shared_rel_pos_bias: + self.rel_pos_bias = RelativePositionBiasWithCLS( + window_size=self.patch_embed.patches_resolution, num_heads=num_heads) + elif use_rel_pos_bias: + self.rel_pos_bias = nn.CellList([ + RelativePositionBiasWithCLS(window_size=self.patch_embed.patches_resolution, + num_heads=num_heads) for _ in range(depth) + ]) + else: + self.rel_pos_bias = None + + dpr = [x.item() for x in np.linspace(0, drop_path_rate, depth)] + self.blocks = nn.CellList([ + Block( + dim=embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + 
attn_drop=attn_drop_rate, proj_drop=proj_drop_rate, attn_head_dim=attn_head_dim, + mlp_ratio=mlp_ratio, drop_path=dpr[i], init_values=init_values, + act_layer=act_layer, norm_layer=norm_layer + ) for i in range(depth) + ]) + + def get_num_layers(self): + return len(self.blocks) + + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def _init_weights(self): + for _, cell in self.cells_and_names(): + if isinstance(cell, nn.Dense): + cell.weight.set_data( + initializer(TruncatedNormal(0.02), cell.weight.shape, cell.weight.dtype) + ) + if cell.bias is not None: + cell.bias.set_data( + initializer('zeros', cell.bias.shape, cell.bias.dtype) + ) + elif isinstance(cell, nn.LayerNorm): + cell.gamma.set_data( + initializer('ones', cell.gamma.shape, cell.gamma.dtype) + ) + cell.beta.set_data( + initializer('zeros', cell.beta.shape, cell.beta.dtype) + ) + elif isinstance(cell, nn.Conv2d): + cell.weight.set_data( + initializer(TruncatedNormal(0.02), cell.weight.shape, cell.weight.dtype) + ) + if cell.bias is not None: + cell.bias.set_data( + initializer('zeros', cell.bias.shape, cell.bias.dtype) + ) + + def _fix_init_weights(self): + for i, block in enumerate(self.blocks): + block.attn.proj.weight.set_data( + ops.div(block.attn.proj.weight, math.sqrt(2.0 * (i + 1))) + ) + block.mlp.fc2.weight.set_data( + ops.div(block.mlp.fc2.weight, math.sqrt(2.0 * (i + 1))) + ) + + def forward_features(self, x): + x = self.patch_embed(x) + bsz = x.shape[0] + + cls_token = ops.broadcast_to(self.cls_token, (bsz, -1, -1)) + cls_token = cls_token.astype(x.dtype) + x = ops.concat((cls_token, x), axis=1) + + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + if isinstance(self.rel_pos_bias, nn.CellList): + for i, blk in enumerate(self.blocks): + rel_pos_bias = self.rel_pos_bias[i]() + x = blk(x, rel_pos_bias) + else: + rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None + for blk in self.blocks: + x = blk(x, rel_pos_bias) + + return x + + def construct(self, x): + return self.forward_features(x) diff --git a/mindcv/optim/__init__.py b/mindcv/optim/__init__.py index 572a3c204..5a61bed91 100644 --- a/mindcv/optim/__init__.py +++ b/mindcv/optim/__init__.py @@ -1,6 +1,6 @@ """ optim init """ from . 
import optim_factory -from .optim_factory import create_optimizer +from .optim_factory import create_finetune_optimizer, create_optimizer, create_pretrain_optimizer __all__ = [] __all__.extend(optim_factory.__all__) diff --git a/mindcv/optim/optim_factory.py b/mindcv/optim/optim_factory.py index 7fe6bf282..754f79266 100644 --- a/mindcv/optim/optim_factory.py +++ b/mindcv/optim/optim_factory.py @@ -1,5 +1,6 @@ """ optim factory """ import os +from functools import partial from typing import Optional from mindspore import load_checkpoint, load_param_into_net, nn @@ -9,7 +10,7 @@ from .lion import Lion from .nadam import NAdam -__all__ = ["create_optimizer"] +__all__ = ["create_optimizer", "create_pretrain_optimizer", "create_finetune_optimizer"] def init_group_params(params, weight_decay): @@ -76,6 +77,219 @@ def create_optimizer( # if lr is not None: # opt_args.setdefault('lr', lr) + optimizer = get_optimizer( + params, opt_args, opt, lr, weight_decay, momentum, nesterov, loss_scale, schedule_decay, checkpoint_path, eps + ) + + return optimizer + + +def get_pretrain_param_groups(model, weight_decay, skip, skip_keywords): + """get pretrain param groups""" + has_decay, has_decay_name = [], [] + no_decay, no_decay_name = [], [] + + for param in model.trainable_params(): + if ( + len(param.shape) == 1 + or param.name.endswith(".bias") + or (param.name in skip) + or check_keywords_in_name(param.name, skip_keywords) + ): + no_decay.append(param) + no_decay_name.append(param.name) + else: + has_decay.append(param) + has_decay_name.append(param.name) + + return [ + {"params": has_decay, "weight_decay": weight_decay}, + {"params": no_decay, "weight_decay": 0.0}, + {"order_params": model.trainable_params()}, + ] + + +def create_pretrain_optimizer( + model, + opt: str = "adam", + lr: Optional[float] = 1e-3, + weight_decay: float = 0, + momentum: float = 0.9, + nesterov: bool = False, + filter_bias_and_bn: bool = True, + loss_scale: float = 1.0, + schedule_decay: float = 4e-3, + checkpoint_path: str = "", + eps: float = 1e-10, + **kwargs, +): + """build pretrain optimizer""" + + opt = opt.lower() + + skip = {} + skip_keywords = {} + if hasattr(model, "no_weight_decay"): + skip = model.no_weight_decay() + if hasattr(model, "no_weight_decay_keywords"): + skip_keywords = model.no_weight_decay_keywords() + + params = get_pretrain_param_groups(model, weight_decay, skip, skip_keywords) + + opt_args = dict(**kwargs) + # if lr is not None: + # opt_args.setdefault('lr', lr) + + optimizer = get_optimizer( + params, opt_args, opt, lr, weight_decay, momentum, nesterov, loss_scale, schedule_decay, checkpoint_path, eps + ) + + return optimizer + + +def get_vit_layer(name, num_layers): + if name in ("cls_token", "mask_token", "pos_embed"): + return 0 + elif name.startswith("patch_embed"): + return 0 + elif name.startswith("rel_pos_bias"): + return num_layers - 1 + elif name.startswith("blocks"): + layer_id = int(name.split(".")[1]) + return layer_id + 1 + else: + return num_layers - 1 + + +def get_swin_layer(name, num_layers, depths): + if name in ("mask_token",): + return 0 + elif name.startswith("patch_embed"): + return 0 + elif name.startswith("layers"): + layer_id = int(name.split(".")[1]) + block_id = name.split(".")[3] + if block_id == "reduction" or block_id == "norm": + return sum(depths[: layer_id + 1]) + layer_id = sum(depths[:layer_id]) + int(block_id) + return layer_id + 1 + else: + return num_layers - 1 + + +def get_finetune_param_groups( + model, + lr, + weight_decay, + get_layer_func, + scales, + skip, 
+ skip_keywords, +): + parameter_group_names = {} + parameter_group_vars = {} + + for param in model.trainable_params(): + if ( + len(param.shape) == 1 + or param.name.endswith(".bias") + or (param.name in skip) + or check_keywords_in_name(param.name, skip_keywords) + ): + group_name = "no_decay" + this_weight_decay = 0.0 + else: + group_name = "decay" + this_weight_decay = weight_decay + if get_layer_func is not None: + layer_id = get_layer_func(param.name) + group_name = "layer_%d_%s" % (layer_id, group_name) + else: + layer_id = None + + if group_name not in parameter_group_names: + if scales is not None: + scale = scales[layer_id] + else: + scale = 1.0 + + parameter_group_names[group_name] = { + "weight_decay": this_weight_decay, + "params": [], + "lr": [learning_rate * scale for learning_rate in lr], + } + parameter_group_vars[group_name] = { + "weight_decay": this_weight_decay, + "params": [], + "lr": [learning_rate * scale for learning_rate in lr], + } + + parameter_group_vars[group_name]["params"].append(param) + parameter_group_names[group_name]["params"].append(param.name) + + return list(parameter_group_vars.values()) + + +def create_finetune_optimizer( + model, + opt: str = "adam", + lr: Optional[float] = 1e-3, + weight_decay: float = 0, + momentum: float = 0.9, + nesterov: bool = False, + filter_bias_and_bn: bool = True, + loss_scale: float = 1.0, + schedule_decay: float = 4e-3, + checkpoint_path: str = "", + eps: float = 1e-10, + scale: float = 0.75, + **kwargs, +): + if hasattr(model, "get_depths"): + depths = model.get_depths() + num_layers = model.get_num_layers() + get_layer_func = partial(get_swin_layer, num_layers=num_layers + 2, depths=depths) + elif hasattr(model, "get_num_layers"): + num_layers = model.get_num_layers() + get_layer_func = partial(get_vit_layer, num_layers=num_layers + 2) + else: + raise NotImplementedError() + + scales = list(scale**i for i in reversed(range(num_layers + 2))) + + skip = {} + skip_keywords = {} + if hasattr(model, "no_weight_decay"): + skip = model.no_weight_decay() + if hasattr(model, "no_weight_decay_keywords"): + skip_keywords = model.no_weight_decay_keywords() + + params = get_finetune_param_groups(model, lr, weight_decay, get_layer_func, scales, skip, skip_keywords) + + opt_args = dict(**kwargs) + # if lr is not None: + # opt_args.setdefault('lr', lr) + + optimizer = get_optimizer( + params, opt_args, opt, lr, weight_decay, momentum, nesterov, loss_scale, schedule_decay, checkpoint_path, eps + ) + + return optimizer + + +def get_optimizer( + params, + opt_args, + opt: str = "adam", + lr: Optional[float] = 1e-3, + weight_decay: float = 0, + momentum: float = 0.9, + nesterov: bool = False, + loss_scale: float = 1.0, + schedule_decay: float = 4e-3, + checkpoint_path: str = "", + eps: float = 1e-10, +): # non-adaptive: SGD, momentum, and nesterov if opt == "sgd": # note: nn.Momentum may perform better if momentum > 0. 
@@ -174,3 +388,11 @@ def create_optimizer( load_param_into_net(optimizer, param_dict) return optimizer + + +def check_keywords_in_name(name, keywords=()): + isin = False + for keyword in keywords: + if keyword in name: + isin = True + return isin diff --git a/mindcv/utils/callbacks.py b/mindcv/utils/callbacks.py index b99ebf24c..357dd1d8f 100644 --- a/mindcv/utils/callbacks.py +++ b/mindcv/utils/callbacks.py @@ -6,7 +6,7 @@ import numpy as np import mindspore as ms -from mindspore import ParameterTuple, Tensor, ops +from mindspore import ParameterTuple, Tensor, nn, ops from mindspore.train import Callback, SummaryRecord, load_param_into_net, save_checkpoint from .checkpoint_manager import CheckpointManager @@ -209,7 +209,7 @@ def on_train_epoch_end(self, run_context): self.ckpt_manager.save_ckpoint( cb_params.train_network, num_ckpt=self.ckpt_keep_max, - metric=res[0], + metric=res[0] if len(self.metric_name) > 0 else 0.0, save_path=ckpt_save_path, ) @@ -278,7 +278,10 @@ def _get_lr_from_cbp(self, cb_params): else: # if the optimizer is successfully called, the global_step will actually be the value of next step. optim_step = optimizer.global_step - 1 if optimizer.dynamic_lr: - lr = optimizer.learning_rate(optim_step)[0] + if isinstance(optimizer.learning_rate, nn.CellList): + lr = optimizer.learning_rate[-1](optim_step)[0] + else: + lr = optimizer.learning_rate(optim_step)[0] else: lr = optimizer.learning_rate return lr diff --git a/mindcv/utils/trainer_factory.py b/mindcv/utils/trainer_factory.py index db47a48e6..1384b8bd8 100644 --- a/mindcv/utils/trainer_factory.py +++ b/mindcv/utils/trainer_factory.py @@ -4,11 +4,10 @@ import mindspore as ms from mindspore import Tensor, context from mindspore import dtype as mstype -from mindspore import nn +from mindspore import nn, ops from mindspore.ops import functional as F from mindspore.train import DynamicLossScaleManager, FixedLossScaleManager, Model -from .amp import auto_mixed_precision from .train_step import TrainStep __all__ = [ @@ -88,6 +87,7 @@ def create_trainer( clip_grad: bool = False, clip_value: float = 15.0, gradient_accumulation_steps: int = 1, + tokenizer: Optional[nn.Cell] = None, ): """Create Trainer. @@ -123,11 +123,15 @@ def create_trainer( if not require_customized_train_step(ema, clip_grad, gradient_accumulation_steps, amp_cast_list): mindspore_kwargs = dict( network=network, - loss_fn=loss, + loss_fn=loss, # for MAE and SimMIM, loss is None and metric is None. 
optimizer=optimizer, - metrics=metrics, + metrics=metrics, # for beit, beit v2, eva and eva-02, metric is None amp_level=amp_level, ) + if tokenizer is not None: + mindspore_kwargs["network"] = WithLossCellForPretrain(network, tokenizer, loss) + mindspore_kwargs.pop("loss_fn") + if loss_scale_type.lower() == "fixed": mindspore_kwargs["loss_scale_manager"] = FixedLossScaleManager( loss_scale=loss_scale, drop_overflow_update=drop_overflow_update @@ -147,9 +151,14 @@ def create_trainer( raise ValueError(f"Loss scale type only support ['fixed', 'dynamic', 'auto'], but got{loss_scale_type}.") model = Model(**mindspore_kwargs) else: # require customized train step - eval_network = nn.WithEvalCell(network, loss, amp_level in ["O2", "O3", "auto"]) - auto_mixed_precision(network, amp_level, amp_cast_list) - net_with_loss = add_loss_network(network, loss, amp_level) + if tokenizer is not None: + net_with_loss = WithLossCellForPretrain(network, tokenizer, loss) # for beit, beit v2, eva, eva-02 + elif loss is None: + net_with_loss = network # for MAE, SimMIM + else: + net_with_loss = nn.WithLossCell(network, loss) + + ms.amp.auto_mixed_precision(net_with_loss, amp_level=amp_level) train_step_kwargs = dict( network=net_with_loss, optimizer=optimizer, @@ -182,6 +191,30 @@ def create_trainer( ) train_step_kwargs["scale_sense"] = update_cell train_step_cell = TrainStep(**train_step_kwargs).set_train() - model = Model(train_step_cell, eval_network=eval_network, metrics=metrics, eval_indexes=[0, 1, 2]) + + if metrics is not None: + model = Model(train_step_cell) + else: + eval_network = nn.WithEvalCell(network, loss, amp_level in ["O2", "O3", "auto"]) + model = Model(train_step_cell, eval_network=eval_network, metrics=metrics, eval_indexes=[0, 1, 2]) # todo: do we need to set model._loss_scale_manager return model + + +class WithLossCellForPretrain(nn.WithLossCell): + def __init__(self, network: nn.Cell, tokenizer: nn.Cell, loss: nn.Cell): + super(WithLossCellForPretrain, self).__init__(network, loss) + self.tokenizer = tokenizer + + def construct(self, x1, x2, mask): + bsz = x1.shape[0] + mask = ops.reshape(mask, (bsz, -1)) + output = self._backbone(x1, mask) + output = ops.transpose(output, (0, 2, 1)) + + label = self.tokenizer(x2) + bool_mask = (1 - mask).astype(ms.bool_) + label = ops.masked_fill(label, bool_mask, value=-100) + + loss = self._loss_fn(output, label) + return loss