diff --git a/mindyolo/utils/utils.py b/mindyolo/utils/utils.py
index 53b74b7a..d3a60aaa 100644
--- a/mindyolo/utils/utils.py
+++ b/mindyolo/utils/utils.py
@@ -89,14 +89,39 @@ def set_default(args):
     args.ema_weight = os.path.join(args.ckpt_dir, args.ema_weight) if args.ema_weight else ""
 
 
-def load_pretrain(network, weight, ema=None, ema_weight=None):
+def drop_inconsistent_shape_parameters(model, param_dict):
+    updated_param_dict = dict()
+
+    # TODO: hard code
+    param_dict = {k.replace('ema.', ''): v for k, v in param_dict.items()}
+
+    for param in model.get_parameters():
+        name = param.name
+        if name in param_dict:
+            if param_dict[name].shape == param.shape:
+                updated_param_dict[name] = param_dict[name]
+            else:
+                logger.warning(
+                    f"Dropping checkpoint parameter `{name}` with shape `{param_dict[name].shape}`, "
+                    f"which is inconsistent with cell shape `{param.shape}`"
+                )
+        else:
+            logger.warning(f"Cannot find checkpoint parameter `{name}`.")
+    return updated_param_dict
+
+
+def load_pretrain(network, weight, ema=None, ema_weight=None, strict=False):
     if weight.endswith(".ckpt"):
         param_dict = ms.load_checkpoint(weight)
+        if not strict:
+            param_dict = drop_inconsistent_shape_parameters(network, param_dict)
         ms.load_param_into_net(network, param_dict)
         logger.info(f'Pretrain model load from "{weight}" success.')
     if ema:
         if ema_weight.endswith(".ckpt"):
             param_dict_ema = ms.load_checkpoint(ema_weight)
+            if not strict:
+                param_dict_ema = drop_inconsistent_shape_parameters(ema.ema, param_dict_ema)
             ms.load_param_into_net(ema.ema, param_dict_ema)
             logger.info(f'Ema pretrain model load from "{ema_weight}" success.')
         else:
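
Usage sketch for reviewers (not part of the diff; `TinyNet`, the layer sizes, and the checkpoint filename are made up for illustration): with `strict=False`, checkpoint entries whose shapes no longer match the cell, such as a detection head built for a different class count, are dropped with a warning instead of failing the load, while everything else loads normally.

import mindspore as ms
from mindspore import nn
from mindyolo.utils.utils import load_pretrain

# Hypothetical toy model: a "backbone" whose shapes match the checkpoint and a
# "head" whose output size differs (e.g. after changing the number of classes).
class TinyNet(nn.Cell):
    def __init__(self, num_classes):
        super().__init__()
        self.backbone = nn.Dense(16, 32)
        self.head = nn.Dense(32, num_classes)

    def construct(self, x):
        return self.head(self.backbone(x))

ms.save_checkpoint(TinyNet(num_classes=80), "tiny_80.ckpt")           # stand-in for a pretrained checkpoint
load_pretrain(TinyNet(num_classes=3), "tiny_80.ckpt", strict=False)   # backbone loads; mismatched head params are skipped with a warning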