Skip to content

Commit

Permalink
fix: fix bugs of pvt
Browse files Browse the repository at this point in the history
  • Loading branch information
The-truthh committed Mar 12, 2024
1 parent 29b7086 commit 9d2f74d
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions mindcv/models/pvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@
from functools import partial
from typing import Optional

import numpy as np

import mindspore
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore.common import initializer as weight_init

from .helpers import load_pretrained
from .layers.compatibility import Dropout
from .layers.compatibility import Dropout, Interpolate
from .layers.drop_path import DropPath
from .layers.identity import Identity
from .layers.mlp import Mlp
Expand Down Expand Up @@ -198,9 +200,7 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, emb
self.num_classes = num_classes
self.depths = depths
self.num_stages = num_stages
start = Tensor(0, mindspore.float32)
stop = Tensor(drop_path_rate, mindspore.float32)
dpr = [float(x) for x in ops.linspace(start, stop, sum(depths))] # stochastic depth decay rule
dpr = [x.item() for x in np.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
cur = 0
b_list = []
self.pos_embed = []
Expand Down Expand Up @@ -292,8 +292,9 @@ def _get_pos_embed(self, pos_embed, ph, pw, H, W):
return pos_embed
else:
pos_embed = self.transpose(self.reshape(pos_embed, (1, ph, pw, -1)), (0, 3, 1, 2))
resize_bilinear = ops.ResizeBilinear((H, W))
pos_embed = resize_bilinear(pos_embed)
# interpolate_fn = Interpolate(mode="bilinear", align_corners=False)
# pos_embed = interpolate_fn(pos_embed, (H, W))
pos_embed = ops.interpolate(pos_embed, (H, W), mode="bilinear", align_corners=False)

pos_embed = self.transpose(self.reshape(pos_embed, (1, -1, H * W)), (0, 2, 1))

Expand Down

0 comments on commit 9d2f74d

Please sign in to comment.