Commit

fix: fix all zero initialized tensor problem for vit (#794)
Co-authored-by: ChongWei905 <weichong4@huawei.com>
ChongWei905 and ChongWei905 authored Jul 23, 2024
1 parent 5c56e78 commit 2a07185
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion mindcv/models/vit.py
@@ -329,7 +329,9 @@ def get_num_layers(self):
     def _init_weights(self):
         w = self.patch_embed.proj.weight
         w_shape_flatted = (w.shape[0], functools.reduce(lambda x, y: x*y, w.shape[1:]))
-        w.set_data(initializer(XavierUniform(), w_shape_flatted, w.dtype).reshape(w.shape))
+        w_value = initializer(XavierUniform(), w_shape_flatted, w.dtype)
+        w_value.init_data()
+        w.set_data(w_value.reshape(w.shape))
         for _, cell in self.cells_and_names():
             if isinstance(cell, nn.Dense):
                 cell.weight.set_data(
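
The change above calls init_data() on the freshly created value before reshaping it. A plausible reading, going by the commit title, is that the value returned by initializer() is filled lazily, so reshaping it first copies a buffer that is still all zeros. The sketch below is a toy model of that behaviour in plain Python. LazyTensor, fill_fn, and the shapes are hypothetical stand-ins for illustration only, not MindSpore's classes or its actual implementation.

```python
import random


class LazyTensor:
    """Hypothetical stand-in for a tensor whose values are filled on demand."""

    def __init__(self, shape, fill_fn):
        self.shape = tuple(shape)
        self.fill_fn = fill_fn
        size = 1
        for dim in self.shape:
            size *= dim
        # Assumption: the buffer holds zeros until init_data() is called.
        self.data = [0.0] * size

    def init_data(self):
        # Materialize the values in place.
        self.data = [self.fill_fn() for _ in self.data]
        return self

    def reshape(self, shape):
        # Assumption: reshape copies the current buffer as-is and does not
        # trigger the pending initialization.
        out = LazyTensor(shape, self.fill_fn)
        out.data = list(self.data)
        return out


rng = random.Random(0)


def fill():
    # Arbitrary uniform range chosen for the demo.
    return rng.uniform(-0.0625, 0.0625)


# Old pattern: reshape straight away, so the zero buffer is what gets copied.
before_fix = LazyTensor((4, 12), fill).reshape((4, 3, 2, 2))

# New pattern: materialize first, then reshape.
after_fix = LazyTensor((4, 12), fill)
after_fix.init_data()
after_fix = after_fix.reshape((4, 3, 2, 2))

print(all(v == 0.0 for v in before_fix.data))
print(any(v != 0.0 for v in after_fix.data))
```

In this toy model only the order of the two calls differs between the patterns, which mirrors the three added lines in the diff.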