Skip to content

Commit

Permalink
merge commits fixing bugs from branch main to branch v0.4.0 (#799)
Browse files Browse the repository at this point in the history
* fix: fix all zero initialized tensor problem for vit (#794)

Co-authored-by: ChongWei905 <weichong4@huawei.com>
(cherry picked from commit 2a07185)

* feat: add validate shuffle config and set default to False (#793)

(cherry picked from commit 1c283bc)

Co-authored-by: ChongWei905 <weichong4@huawei.com>
(cherry picked from commit 69ae862)

* fix: fix load checkpoint failure for deeplabv3 (#790)

Co-authored-by: ChongWei905 <weichong4@huawei.com>
(cherry picked from commit c744121)

* ops.ResizeBilinear has been deprecated, thus using ops.ResizeBilinearV2 instead (#789)

(cherry picked from commit a6d45f3)

---------

Co-authored-by: Pingqi Li <58093835+PingqiLi@users.noreply.github.com>
  • Loading branch information
ChongWei905 and PingqiLi authored Jul 25, 2024
1 parent d0eaa0f commit 9cddc51
Show file tree
Hide file tree
Showing 7 changed files with 10 additions and 5 deletions.
2 changes: 2 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,8 @@ def create_parser():
help='Loss scale (default=1.0)')
group.add_argument('--drop_overflow_update', type=bool, default=False,
help='Whether to execute optimizer if there is an overflow (default=False)')
group.add_argument('--eval_shuffle', type=bool, default=False,
help='Whether to shuffle the evaluation data (default=False)')

return parser_config, parser
# fmt: on
Expand Down
2 changes: 1 addition & 1 deletion examples/det/ssd/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def construct(self, inputs):
top = len(inputs) - i - 1
down = top - 1
size = ops.shape(inputs[down])
top_down = ops.ResizeBilinear((size[2], size[3]))(features[-1])
top_down = ops.ResizeBilinearV2()(features[-1], (size[2], size[3]))
top_down = top_down + image_features[down]
features = features + (top_down,)

Expand Down
2 changes: 1 addition & 1 deletion examples/seg/deeplabv3/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def on_train_end(self, run_context):

def get_segment_train_callback(args, steps_per_epoch, rank_id):
callbacks = [TimeMonitor(data_size=steps_per_epoch), LossMonitor()]
if rank_id == 0:
if rank_id == 0 or rank_id is None:
ckpt_config = CheckpointConfig(
save_checkpoint_steps=args.save_steps,
keep_checkpoint_max=args.keep_checkpoint_max,
Expand Down
2 changes: 1 addition & 1 deletion examples/seg/deeplabv3/deeplabv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ class DeepLabInferNetwork(nn.Cell):
"""

def __init__(self, network, input_format="NCHW"):
super(DeepLabInferNetwork, self).__init__()
super(DeepLabInferNetwork, self).__init__(auto_prefix=False)
self.network = network
self.softmax = nn.Softmax(axis=1)
self.format = input_format
Expand Down
2 changes: 1 addition & 1 deletion examples/seg/deeplabv3/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def train(args):
callbacks = get_segment_train_callback(args, steps_per_epoch, rank_id)

# eval when train
if args.eval_while_train and rank_id == 0:
if args.eval_while_train and (rank_id == 0 or rank_id is None):
eval_model = DeepLabInferNetwork(deeplab, input_format=args.input_format)
eval_dataset = create_segment_dataset(
name=args.dataset,
Expand Down
4 changes: 3 additions & 1 deletion mindcv/models/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,9 @@ def get_num_layers(self):
def _init_weights(self):
w = self.patch_embed.proj.weight
w_shape_flatted = (w.shape[0], functools.reduce(lambda x, y: x*y, w.shape[1:]))
w.set_data(initializer(XavierUniform(), w_shape_flatted, w.dtype).reshape(w.shape))
w_value = initializer(XavierUniform(), w_shape_flatted, w.dtype)
w_value.init_data()
w.set_data(w_value.reshape(w.shape))
for _, cell in self.cells_and_names():
if isinstance(cell, nn.Dense):
cell.weight.set_data(
Expand Down
1 change: 1 addition & 0 deletions validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def validate(args):
split=args.val_split,
num_parallel_workers=args.num_parallel_workers,
download=args.dataset_download,
shuffle=args.eval_shuffle,
)

# create transform
Expand Down

0 comments on commit 9cddc51

Please sign in to comment.