From 2a07185b368c66eea195531501e658fc30b22b91 Mon Sep 17 00:00:00 2001 From: Nemo <132769294+ChongWei905@users.noreply.github.com> Date: Tue, 23 Jul 2024 11:47:26 +0800 Subject: [PATCH] fix: fix all zero initialized tensor problem for vit (#794) Co-authored-by: ChongWei905 --- mindcv/models/vit.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mindcv/models/vit.py b/mindcv/models/vit.py index c158e6a3..c8ee0967 100644 --- a/mindcv/models/vit.py +++ b/mindcv/models/vit.py @@ -329,7 +329,9 @@ def get_num_layers(self): def _init_weights(self): w = self.patch_embed.proj.weight w_shape_flatted = (w.shape[0], functools.reduce(lambda x, y: x*y, w.shape[1:])) - w.set_data(initializer(XavierUniform(), w_shape_flatted, w.dtype).reshape(w.shape)) + w_value = initializer(XavierUniform(), w_shape_flatted, w.dtype) + w_value.init_data() + w.set_data(w_value.reshape(w.shape)) for _, cell in self.cells_and_names(): if isinstance(cell, nn.Dense): cell.weight.set_data(