Skip to content

Commit

Permalink
optional flag to interpolate positional encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
thayeral committed Jan 6, 2025
1 parent b40a6a3 commit d9033f4
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions src/patch_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,8 @@ def __init__(
temporal_patch_size=1,
embed_dim=768,
channels=1,
cls_token=False
cls_token=False,
interpolate=False,
):
super().__init__()
self.input_shape = input_shape
Expand All @@ -252,6 +253,7 @@ def __init__(
self.embed_dim = embed_dim
self.channels = channels
self.cls_token = cls_token
self.interpolate = interpolate

self.num_patches, self.token_shape = calc_num_patches(
input_fmt=self.input_fmt,
Expand Down Expand Up @@ -466,4 +468,7 @@ def interpolate_positional_encoding(self, x, pos_embed):
return pos_embed

def forward(self, x):
return self.interpolate_positional_encoding(x, self.pos_embed)
if self.interpolate:
return self.interpolate_positional_encoding(x, self.pos_embed)
else:
return self.pos_embed

0 comments on commit d9033f4

Please sign in to comment.