Merge pull request #181 from zhtmike/non_strict
Support non-strict loading of the pretrained model.
zhanghuiyao authored Aug 14, 2023
2 parents 0917faa + 0908921 commit 07997fd
Showing 2 changed files with 28 additions and 2 deletions.
mindyolo/utils/utils.py: 26 additions, 1 deletion
@@ -89,14 +89,39 @@ def set_default(args):
     args.ema_weight = os.path.join(args.ckpt_dir, args.ema_weight) if args.ema_weight else ""


-def load_pretrain(network, weight, ema=None, ema_weight=None):
+def drop_inconsistent_shape_parameters(model, param_dict):
+    updated_param_dict = dict()
+
+    # TODO: hard code
+    param_dict = {k.replace('ema.', ''): v for k, v in param_dict.items()}
+
+    for param in model.get_parameters():
+        name = param.name
+        if name in param_dict:
+            if param_dict[name].shape == param.shape:
+                updated_param_dict[name] = param_dict[name]
+            else:
+                logger.warning(
+                    f"Dropping checkpoint parameter `{name}` with shape `{param_dict[name].shape}`, "
+                    f"which is inconsistent with cell shape `{param.shape}`"
+                )
+        else:
+            logger.warning(f"Cannot find checkpoint parameter `{name}`.")
+    return updated_param_dict
+
+
+def load_pretrain(network, weight, ema=None, ema_weight=None, strict=True):
     if weight.endswith(".ckpt"):
         param_dict = ms.load_checkpoint(weight)
+        if not strict:
+            param_dict = drop_inconsistent_shape_parameters(network, param_dict)
         ms.load_param_into_net(network, param_dict)
         logger.info(f'Pretrain model load from "{weight}" success.')
     if ema:
         if ema_weight.endswith(".ckpt"):
             param_dict_ema = ms.load_checkpoint(ema_weight)
+            if not strict:
+                param_dict_ema = drop_inconsistent_shape_parameters(ema.ema, param_dict_ema)
             ms.load_param_into_net(ema.ema, param_dict_ema)
             logger.info(f'Ema pretrain model load from "{ema_weight}" success.')
         else:
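For reference, a minimal usage sketch of the new non-strict path (not part of this commit). It assumes network is an already-built mindyolo model and finetune_head.ckpt is a hypothetical checkpoint whose detection head has a different shape, e.g. a different number of classes:

    from mindyolo.utils.utils import load_pretrain

    # With strict=False, checkpoint parameters whose shapes do not match the
    # network (e.g. the class-prediction layers) are dropped with a warning by
    # drop_inconsistent_shape_parameters; all other weights load as usual.
    load_pretrain(network, "finetune_head.ckpt", ema=None, ema_weight=None, strict=False)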
train.py: 2 additions, 1 deletion
@@ -69,6 +69,7 @@ def get_parser_train(parents=None):
     parser.add_argument("--profiler", type=ast.literal_eval, default=False, help="collect profiling data or not")
     parser.add_argument("--profiler_step_num", type=int, default=1, help="collect profiler data for how many steps.")
     parser.add_argument("--opencv_threads_num", type=int, default=2, help="set the number of threads for opencv")
+    parser.add_argument("--strict_load", type=ast.literal_eval, default=True, help="strictly load the pretrain model")

     # args for ModelArts
     parser.add_argument("--enable_modelarts", type=ast.literal_eval, default=False, help="enable modelarts")

@@ -112,7 +113,7 @@ def train(args):
         ema = EMA(network, ema_network)
     else:
         ema = None
-    load_pretrain(network, args.weight, ema, args.ema_weight)  # load pretrain
+    load_pretrain(network, args.weight, ema, args.ema_weight, args.strict_load)  # load pretrain
     freeze_layers(network, args.freeze)  # freeze Layers
     ms.amp.auto_mixed_precision(network, amp_level=args.ms_amp_level)
     if ema:
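With the new argument wired into train(), a non-strict load could be requested at launch, e.g. python train.py --weight <pretrained>.ckpt --strict_load False alongside the usual training options (the checkpoint path is a placeholder, and --weight is assumed to be the existing option behind args.weight). Because --strict_load is parsed with ast.literal_eval, its value must be a Python literal such as True or False.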
