From 5799a3dfd76215624ad38cbcda46bc279d576fd9 Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Wed, 2 Aug 2023 12:01:43 +0800 Subject: [PATCH 1/2] support non-strict load of pretrain model --- mindyolo/utils/utils.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/mindyolo/utils/utils.py b/mindyolo/utils/utils.py index 53b74b7a..d3a60aaa 100644 --- a/mindyolo/utils/utils.py +++ b/mindyolo/utils/utils.py @@ -89,14 +89,39 @@ def set_default(args): args.ema_weight = os.path.join(args.ckpt_dir, args.ema_weight) if args.ema_weight else "" -def load_pretrain(network, weight, ema=None, ema_weight=None): +def drop_inconsistent_shape_parameters(model, param_dict): + updated_param_dict = dict() + + # TODO: hard code + param_dict = {k.replace('ema.', ''): v for k, v in param_dict.items()} + + for param in model.get_parameters(): + name = param.name + if name in param_dict: + if param_dict[name].shape == param.shape: + updated_param_dict[name] = param_dict[name] + else: + logger.warning( + f"Dropping checkpoint parameter `{name}` with shape `{param_dict[name].shape}`, " + f"which is inconsistent with cell shape `{param.shape}`" + ) + else: + logger.warning(f"Cannot find checkpoint parameter `{name}`.") + return updated_param_dict + + +def load_pretrain(network, weight, ema=None, ema_weight=None, strict=False): if weight.endswith(".ckpt"): param_dict = ms.load_checkpoint(weight) + if not strict: + param_dict = drop_inconsistent_shape_parameters(network, param_dict) ms.load_param_into_net(network, param_dict) logger.info(f'Pretrain model load from "{weight}" success.') if ema: if ema_weight.endswith(".ckpt"): param_dict_ema = ms.load_checkpoint(ema_weight) + if not strict: + param_dict_ema = drop_inconsistent_shape_parameters(ema.ema, param_dict_ema) ms.load_param_into_net(ema.ema, param_dict_ema) logger.info(f'Ema pretrain model load from "{ema_weight}" success.') else: From 09089219c6566c2f71d825341136ba7ca489783e Mon Sep 17 
00:00:00 2001 From: Mike Cheung Date: Wed, 2 Aug 2023 16:20:15 +0800 Subject: [PATCH 2/2] add argument for strict load and set default to True --- mindyolo/utils/utils.py | 2 +- train.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mindyolo/utils/utils.py b/mindyolo/utils/utils.py index d3a60aaa..eb8ef071 100644 --- a/mindyolo/utils/utils.py +++ b/mindyolo/utils/utils.py @@ -110,7 +110,7 @@ def drop_inconsistent_shape_parameters(model, param_dict): return updated_param_dict -def load_pretrain(network, weight, ema=None, ema_weight=None, strict=False): +def load_pretrain(network, weight, ema=None, ema_weight=None, strict=True): if weight.endswith(".ckpt"): param_dict = ms.load_checkpoint(weight) if not strict: diff --git a/train.py b/train.py index 393ba490..710e1727 100644 --- a/train.py +++ b/train.py @@ -69,6 +69,7 @@ def get_parser_train(parents=None): parser.add_argument("--profiler", type=ast.literal_eval, default=False, help="collect profiling data or not") parser.add_argument("--profiler_step_num", type=int, default=1, help="collect profiler data for how many steps.") parser.add_argument("--opencv_threads_num", type=int, default=2, help="set the number of threads for opencv") + parser.add_argument("--strict_load", type=ast.literal_eval, default=True, help="strictly load the pretrain model") # args for ModelArts parser.add_argument("--enable_modelarts", type=ast.literal_eval, default=False, help="enable modelarts") @@ -112,7 +113,7 @@ def train(args): ema = EMA(network, ema_network) else: ema = None - load_pretrain(network, args.weight, ema, args.ema_weight) # load pretrain + load_pretrain(network, args.weight, ema, args.ema_weight, args.strict_load) # load pretrain freeze_layers(network, args.freeze) # freeze Layers ms.amp.auto_mixed_precision(network, amp_level=args.ms_amp_level) if ema: