diff --git a/mindyolo/utils/utils.py b/mindyolo/utils/utils.py
index dbdd54c6..a9ba5bd3 100644
--- a/mindyolo/utils/utils.py
+++ b/mindyolo/utils/utils.py
@@ -89,14 +89,54 @@ def set_default(args):
         args.ema_weight = os.path.join(args.ckpt_dir, args.ema_weight) if args.ema_weight else ""
 
 
-def load_pretrain(network, weight, ema=None, ema_weight=None):
+def drop_inconsistent_shape_parameters(model, param_dict):
+    """Keep only the checkpoint entries whose shape matches the model's parameters.
+
+    Args:
+        model: Cell whose ``get_parameters()`` yields objects with ``name`` and ``shape``.
+        param_dict: Mapping from checkpoint parameter name to a value exposing ``shape``.
+
+    Returns:
+        dict: The entries of ``param_dict`` (a leading ``ema.`` removed from each name) whose
+        name matches a parameter of ``model`` with an identical shape.
+    """
+    # TODO: hard code. Presumably `ema.` is the name prefix used by EMA checkpoints; strip it
+    # only as a prefix so a name that merely contains `ema.` is left intact.
+    ema_prefix = "ema."
+    param_dict = {
+        (k[len(ema_prefix) :] if k.startswith(ema_prefix) else k): v
+        for k, v in param_dict.items()
+    }
+
+    updated_param_dict = {}
+    for param in model.get_parameters():
+        name = param.name
+        if name not in param_dict:
+            logger.warning(f"Cannot find checkpoint parameter `{name}`.")
+            continue
+        ckpt_shape = param_dict[name].shape
+        if ckpt_shape == param.shape:
+            updated_param_dict[name] = param_dict[name]
+        else:
+            logger.warning(
+                f"Dropping checkpoint parameter `{name}` with shape `{ckpt_shape}`, "
+                f"which is inconsistent with cell shape `{param.shape}`"
+            )
+    return updated_param_dict
+
+
+def load_pretrain(network, weight, ema=None, ema_weight=None, strict=True):
     if weight.endswith(".ckpt"):
         param_dict = ms.load_checkpoint(weight)
+        if not strict:
+            param_dict = drop_inconsistent_shape_parameters(network, param_dict)
         ms.load_param_into_net(network, param_dict)
         logger.info(f'Pretrain model load from "{weight}" success.')
         if ema:
             if ema_weight.endswith(".ckpt"):
                 param_dict_ema = ms.load_checkpoint(ema_weight)
+                if not strict:
+                    param_dict_ema = drop_inconsistent_shape_parameters(ema.ema, param_dict_ema)
                 ms.load_param_into_net(ema.ema, param_dict_ema)
                 logger.info(f'Ema pretrain model load from "{ema_weight}" success.')
             else:
diff --git a/train.py b/train.py
index 393ba490..710e1727 100644
--- a/train.py
+++ b/train.py
@@ -69,6 +69,7 @@ def get_parser_train(parents=None):
     parser.add_argument("--profiler", type=ast.literal_eval, default=False, help="collect profiling data or not")
     parser.add_argument("--profiler_step_num", type=int, default=1, help="collect profiler data for how many steps.")
     parser.add_argument("--opencv_threads_num", type=int, default=2, help="set the number of threads for opencv")
+    parser.add_argument("--strict_load", type=ast.literal_eval, default=True, help="strictly load the pretrain model")
 
     # args for ModelArts
     parser.add_argument("--enable_modelarts", type=ast.literal_eval, default=False, help="enable modelarts")
@@ -112,7 +113,7 @@ def train(args):
         ema = EMA(network, ema_network)
     else:
         ema = None
-    load_pretrain(network, args.weight, ema, args.ema_weight)  # load pretrain
+    load_pretrain(network, args.weight, ema, args.ema_weight, args.strict_load)  # load pretrain
     freeze_layers(network, args.freeze)  # freeze Layers
     ms.amp.auto_mixed_precision(network, amp_level=args.ms_amp_level)
     if ema: