Skip to content

Commit

Permalink
support non-strict load of pretrain model
Browse files Browse the repository at this point in the history
  • Loading branch information
zhtmike committed Aug 2, 2023
1 parent b80bb44 commit 5799a3d
Showing 1 changed file with 26 additions and 1 deletion.
27 changes: 26 additions & 1 deletion mindyolo/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,39 @@ def set_default(args):
args.ema_weight = os.path.join(args.ckpt_dir, args.ema_weight) if args.ema_weight else ""


def load_pretrain(network, weight, ema=None, ema_weight=None):
def drop_inconsistent_shape_parameters(model, param_dict):
    """Keep only the checkpoint entries whose shape matches a parameter of ``model``.

    A leading ``"ema."`` is stripped from every checkpoint key before matching.
    Only a *leading* prefix is removed, so names that merely contain the
    substring (for example ``"schema.weight"``) are left untouched.

    Args:
        model: Object exposing ``get_parameters()``, which yields parameters
            carrying ``name`` and ``shape`` attributes.
        param_dict: Mapping from checkpoint parameter name to a value that has
            a ``shape`` attribute.

    Returns:
        A new dict holding, for each parameter of ``model`` found in the
        checkpoint with an identical shape, the checkpoint value keyed by the
        parameter name. Mismatched or missing parameters are omitted and
        reported through the module-level ``logger``.
    """
    # TODO: hard code
    ema_prefix = "ema."
    normalized = {
        (key[len(ema_prefix):] if key.startswith(ema_prefix) else key): value
        for key, value in param_dict.items()
    }

    updated_param_dict = {}
    for param in model.get_parameters():
        name = param.name
        if name not in normalized:
            logger.warning(f"Cannot find checkpoint parameter `{name}`.")
            continue

        ckpt_value = normalized[name]
        if ckpt_value.shape != param.shape:
            # Shape mismatch: skip the entry so that a later load does not fail on it.
            logger.warning(
                f"Dropping checkpoint parameter `{name}` with shape `{ckpt_value.shape}`, "
                f"which is inconsistent with cell shape `{param.shape}`"
            )
            continue

        updated_param_dict[name] = ckpt_value
    return updated_param_dict


def load_pretrain(network, weight, ema=None, ema_weight=None, strict=False):
if weight.endswith(".ckpt"):
param_dict = ms.load_checkpoint(weight)
if not strict:
param_dict = drop_inconsistent_shape_parameters(network, param_dict)
ms.load_param_into_net(network, param_dict)
logger.info(f'Pretrain model load from "{weight}" success.')
if ema:
if ema_weight.endswith(".ckpt"):
param_dict_ema = ms.load_checkpoint(ema_weight)
if not strict:
param_dict_ema = drop_inconsistent_shape_parameters(ema.ema, param_dict_ema)
ms.load_param_into_net(ema.ema, param_dict_ema)
logger.info(f'Ema pretrain model load from "{ema_weight}" success.')
else:
Expand Down

0 comments on commit 5799a3d

Please sign in to comment.