diff --git a/mindcv/models/vit.py b/mindcv/models/vit.py index c158e6a3..c8ee0967 100644 --- a/mindcv/models/vit.py +++ b/mindcv/models/vit.py @@ -329,7 +329,9 @@ def get_num_layers(self): def _init_weights(self): w = self.patch_embed.proj.weight w_shape_flatted = (w.shape[0], functools.reduce(lambda x, y: x*y, w.shape[1:])) - w.set_data(initializer(XavierUniform(), w_shape_flatted, w.dtype).reshape(w.shape)) + w_value = initializer(XavierUniform(), w_shape_flatted, w.dtype) + w_value.init_data() + w.set_data(w_value.reshape(w.shape)) for _, cell in self.cells_and_names(): if isinstance(cell, nn.Dense): cell.weight.set_data(