Skip to content

Commit

Permalink
Rename to layer
Browse files Browse the repository at this point in the history
  • Loading branch information
jloveric committed May 4, 2024
1 parent 1c90504 commit adb11eb
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
8 changes: 6 additions & 2 deletions language_interpolation/dual_convolutional_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torch import Tensor


class DualConvolutionalNetwork(torch.nn.Module):
class DualConvolutionalLayer(torch.nn.Module):
def __init__(
self,
n: str,
Expand Down Expand Up @@ -53,14 +53,18 @@ def forward(self, x: Tensor):
"""
x has shape [B, L, D]
"""

xshape = x.shape
nx = x.reshape(xshape[0]*xshape[1],xshape[2])

val = self.input_layer(nx)
val = val.reshape(x.shape[0],x.shape[1],-1)

# Gradients apparently automatically accumulate, though probably want
# some normalization here
depth = 0
while val.shape[1] > 1:
depth+=1
if val.shape[1] % 2 == 1:
# Add padding to the end, hope this doesn't bust anything
val = torch.cat([val, torch.zeros(val.shape[0], 1, val.shape[2])], dim=1)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_dual_convolution.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import pytest

from language_interpolation.dual_convolutional_network import DualConvolutionalNetwork
from language_interpolation.dual_convolutional_network import DualConvolutionalLayer
import torch


def test_dual_convolution():
net = DualConvolutionalNetwork(n=3, in_width=1, out_width=10, hidden_layers=2, hidden_width=10, in_segments=128, segments=5)
net = DualConvolutionalLayer(n=3, in_width=1, out_width=10, hidden_layers=2, hidden_width=10, in_segments=128, segments=5)
x = torch.rand(10, 15, 1) # character level
res = net(x)
print('res', res)
print('res', res.shape)

0 comments on commit adb11eb

Please sign in to comment.