Commit 13b36fc
Updating with MaxAbsNormalizationLast
jloveric committed Dec 9, 2023
1 parent 4ecd468 commit 13b36fc
Showing 3 changed files with 11 additions and 14 deletions.
19 changes: 8 additions & 11 deletions language_interpolation/networks.py
@@ -12,7 +12,7 @@
     initialize_network_polynomial_layers,
     initialize_polynomial_layer
 )
-from high_order_layers_torch.layers import MaxAbsNormalization, high_order_fc_layers
+from high_order_layers_torch.layers import MaxAbsNormalizationLast, high_order_fc_layers
 from torchmetrics import Accuracy
 from torch import Tensor

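For orientation, here is a minimal sketch of what a max-abs normalization over the last dimension plausibly computes, inferred from the layer's name and from the call-site change later in this diff (MaxAbsNormalizationLast(eps=1e-6) replacing MaxAbsNormalization(eps=1e-6, dim=2)); the actual layer in high_order_layers_torch may differ in detail:

import torch
from torch import Tensor

def max_abs_normalize_last(x: Tensor, eps: float = 1e-6) -> Tensor:
    # Scale each vector along the last dimension so its largest
    # absolute entry becomes ~1; eps guards against division by zero.
    scale = x.abs().amax(dim=-1, keepdim=True)
    return x / (scale + eps)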
@@ -192,13 +192,15 @@ def forward(
         qt = self.normalization(qt.view(query.shape[0], query.shape[1], qt.shape[1]))
         kt = self.normalization(kt.view(key.shape[0], key.shape[1], kt.shape[1]))
         vt = self.normalization(vt.view(value.shape[0], value.shape[1], vt.shape[1]))

         qth = qt.reshape(qt.shape[0], qt.shape[1], self.heads, -1)
         kth = kt.reshape(kt.shape[0], kt.shape[1], self.heads, -1)
         vth = vt.reshape(vt.shape[0], vt.shape[1], self.heads, -1)

         qkh = torch.nn.functional.softmax(torch.einsum('blhd,brhd->blrh',qth,kth), dim=3)
-        res =torch.einsum('blrh,brhd->blhd',qkh, vth)
+        res = torch.einsum('blrh,brhd->blhd',qkh, vth)
+
+

         v = res.reshape(res.shape[0] * res.shape[1], -1)
         output = self.output_layer(v)
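As an aside, a self-contained sketch of the einsum attention pattern above, with made-up shapes (b=batch, l=query length, r=key length, h=heads, d=per-head width):

import torch

b, l, r, h, d = 2, 5, 5, 4, 8
qth = torch.randn(b, l, h, d)
kth = torch.randn(b, r, h, d)
vth = torch.randn(b, r, h, d)

# Per-head dot products between every query and key position.
scores = torch.einsum('blhd,brhd->blrh', qth, kth)
# Note: the commit softmaxes over dim=3 (the head axis); conventional
# attention normalizes over the key axis, dim=2, as shown here.
qkh = torch.nn.functional.softmax(scores, dim=2)
# Attention-weighted sum of the values.
res = torch.einsum('blrh,brhd->blhd', qkh, vth)
assert res.shape == (b, l, h, d)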
@@ -354,7 +356,7 @@ def __init__(
         out_dim = layers[-1][1]
         mlp_normalization = None
         if normalization is not None :
-            mlp_normalization = MaxAbsNormalization
+            mlp_normalization = MaxAbsNormalizationLast

         self._output_layer = HighOrderMLP(
             layer_type=layer_type,
@@ -404,17 +406,12 @@ def forward(self, x: Tensor) -> Tensor:

         average = torch.sum(res, dim=1) / res.shape[1]

-        """
-        if self.normalization is not None :
-            final = self.normalization(self._output_layer(average))
-        else :
-            final = self._output_layer(average)
-        """

         final = self._output_layer(average)
         # print("final network outputs size", torch.numel(final))

-        torch.cuda.empty_cache()
+        #torch.cuda.empty_cache()
         return final
         # return self.model(x)

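For reference, the pooling kept above is simply a mean over the sequence dimension; a quick equivalence check with illustrative shapes:

import torch

res = torch.randn(4, 7, 16)  # (batch, sequence, features), made up
average = torch.sum(res, dim=1) / res.shape[1]
assert torch.allclose(average, res.mean(dim=1))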
@@ -459,7 +456,7 @@ def select_network(cfg: DictConfig, device: str = None):

     normalizer = None
     if normalization==True :
-        normalizer=MaxAbsNormalization(eps=1e-6, dim=2)
+        normalizer=MaxAbsNormalizationLast(eps=1e-6)


     model = HighOrderAttentionNetwork(
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -17,7 +17,7 @@ pytorch-lightning = "^2.0.0"
 langchain = "^0.0.334"
 hydra-core = "^1.3.2"
 tensorboard = "^2.15.1"
-high-order-layers-torch = "^2.2.0"
+high-order-layers-torch = "^2.2.1"

 [tool.poetry.group.dev.dependencies]
 black = "^23.11.0"
4 changes: 2 additions & 2 deletions tests/test_attention_network.py
@@ -2,7 +2,7 @@

 from language_interpolation.networks import HighOrderAttentionNetwork, large_character_spacing, small_character_spacing
 from language_interpolation.lightning_datamodule import TransformerDataModule
-from high_order_layers_torch.layers import MaxAbsNormalization
+from high_order_layers_torch.layers import MaxAbsNormalizationLast
 from omegaconf import DictConfig
 from language_interpolation.utils import generate_transformer_text
 import torch
@@ -32,7 +32,7 @@ def test_attention_network():
     assert input_data.shape[0] == 32
     assert input_data.shape[2] == 10

-    normalization = MaxAbsNormalization(eps=1e-6, dim=2)
+    normalization = MaxAbsNormalizationLast(eps=1e-6)

     network = HighOrderAttentionNetwork(
         layers=[[10, 5, 3], [5, 5, 2]],
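A quick sanity check in the spirit of this test, assuming (from the name and the dropped dim argument) that MaxAbsNormalizationLast scales along the last dimension:

import torch
from high_order_layers_torch.layers import MaxAbsNormalizationLast

norm = MaxAbsNormalizationLast(eps=1e-6)
x = torch.randn(32, 10, 16)  # illustrative (batch, sequence, features)
y = norm(x)
# If the layer divides by the max absolute value along the last
# dimension, each vector's largest |entry| should now be close to 1.
print(y.abs().amax(dim=-1).min(), y.abs().amax(dim=-1).max())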
