Commit 7b1ad16: Adding dual conv example
jloveric committed May 4, 2024 (1 parent: adb11eb)

Showing 3 changed files with 75 additions and 8 deletions.
16 changes: 16 additions & 0 deletions config/net/dual_convolution.yaml
@@ -0,0 +1,16 @@
+layer_type: continuous
+normalize: layer # maxabs
+
+model_type: dual_convolutional_network
+n: 3
+
+segments: 2
+base_width: 8
+
+in_width: 1
+out_width: 128
+embedding_dimension: 256
+hidden_width: 128
+hidden_layers: 2
+in_segments: 128
+accelerator: cuda
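
For orientation, a minimal sketch of reading this file directly, assuming OmegaConf (the library underlying Hydra configs); in the repo these fields are consumed as cfg.net.* in networks.py below.

from omegaconf import OmegaConf

# Load the new config and inspect a couple of the fields added above.
cfg = OmegaConf.load("config/net/dual_convolution.yaml")
print(cfg.model_type, cfg.n, cfg.embedding_dimension)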
54 changes: 46 additions & 8 deletions language_interpolation/dual_convolutional_network.py
@@ -21,7 +21,7 @@ def __init__(
         device: str = "cpu",
     ):
         super().__init__()
-
+        self._out_width = out_width
 
         self.input_layer = HighOrderMLP(
@@ -39,7 +39,7 @@ def __init__(
         self.equal_layers = HighOrderMLP(
             layer_type="continuous",
             n=n,
-            in_width=2*out_width,
+            in_width=2 * out_width,
             out_width=out_width,
             hidden_layers=hidden_layers,
             hidden_width=hidden_width,
@@ -55,23 +55,61 @@ def forward(self, x: Tensor):
         """
 
         xshape = x.shape
-        nx = x.reshape(xshape[0]*xshape[1],xshape[2])
+        nx = x.reshape(xshape[0] * xshape[1], xshape[2])
 
         val = self.input_layer(nx)
-        val = val.reshape(x.shape[0],x.shape[1],-1)
+        val = val.reshape(x.shape[0], x.shape[1], -1)
 
         # The same equal_layers MLP is applied at every depth, so its
         # gradients accumulate across passes; some normalization is
         # probably wanted here.
         depth = 0
         while val.shape[1] > 1:
-            depth+=1
+            depth += 1
             if val.shape[1] % 2 == 1:
                 # Pad the sequence to even length so it splits cleanly
                 # into pairs (zeros created on the input's device).
-                val = torch.cat([val, torch.zeros(val.shape[0], 1, val.shape[2])], dim=1)
+                val = torch.cat(
+                    [val, torch.zeros(val.shape[0], 1, val.shape[2], device=val.device)],
+                    dim=1,
+                )
 
             valshape = val.shape
-            val = val.reshape(-1,2*self._out_width)
+            val = val.reshape(-1, 2 * self._out_width)
             val = self.equal_layers(val)
-            val = val.reshape(valshape[0],-1,self._out_width)
+            val = val.reshape(valshape[0], -1, self._out_width)
         return val
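
The loop above implements a binary-tree reduction: each pass concatenates adjacent pairs of out_width-wide embeddings into 2 * out_width-wide vectors (which is why equal_layers is built with in_width=2 * out_width), maps every pair back to out_width, and halves the sequence length until a single vector remains. A minimal shape trace of that idea, with a toy element-wise sum standing in for equal_layers:

import torch

out_width = 4
val = torch.randn(2, 5, out_width)  # (batch, sequence, features)
while val.shape[1] > 1:
    if val.shape[1] % 2 == 1:
        # pad odd-length sequences so they split cleanly into pairs
        val = torch.cat([val, torch.zeros(val.shape[0], 1, val.shape[2])], dim=1)
    pairs = val.reshape(-1, 2 * out_width)  # adjacent embeddings side by side
    merged = pairs[:, :out_width] + pairs[:, out_width:]  # stand-in for equal_layers
    val = merged.reshape(val.shape[0], -1, out_width)
    print(tuple(val.shape))  # (2, 3, 4) -> (2, 2, 4) -> (2, 1, 4)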


+class DualConvolutionNetwork(torch.nn.Module):
+    def __init__(
+        self,
+        n: int,
+        in_width: int,
+        out_width: int,
+        embedding_dimension: int,
+        hidden_layers: int,
+        hidden_width: int,
+        in_segments: int = None,
+        segments: int = None,
+        device: str = "cpu",
+    ):
+        super().__init__()
+
+        self.dual_layer = DualConvolutionalLayer(
+            n=n,
+            in_width=in_width,
+            out_width=embedding_dimension,
+            hidden_layers=hidden_layers,
+            hidden_width=hidden_width,
+            in_segments=in_segments,
+            segments=segments,
+            device=device,
+        )
+
+        self.output_mlp = torch.nn.Linear(
+            in_features=embedding_dimension,
+            out_features=out_width,
+            device=device,
+        )
+
+    def forward(self, x):
+        x = self.dual_layer(x)
+        x = self.output_mlp(x)
+        return x
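
A hedged usage sketch of the new module, reusing the widths from config/net/dual_convolution.yaml; the (batch, sequence, in_width) input layout is inferred from the reshapes in DualConvolutionalLayer.forward:

import torch
from language_interpolation.dual_convolutional_network import DualConvolutionNetwork

net = DualConvolutionNetwork(
    n=3,
    in_width=1,
    out_width=128,
    embedding_dimension=256,
    hidden_layers=2,
    hidden_width=128,
    in_segments=128,
    segments=2,
)
x = torch.rand(4, 16, 1)  # 4 sequences of 16 one-dimensional inputs
y = net(x)                # sequence reduction: 16 -> 8 -> 4 -> 2 -> 1
print(y.shape)            # torch.Size([4, 1, 128])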
13 changes: 13 additions & 0 deletions language_interpolation/networks.py
@@ -25,6 +25,7 @@
 from lion_pytorch import Lion
 from high_order_layers_torch.sparse_optimizers import SparseLion
 from language_interpolation.state_space_network import Mamba
+from language_interpolation.dual_convolutional_network import DualConvolutionNetwork
 
 logger = logging.getLogger(__name__)
 logging.basicConfig(level=logging.DEBUG)
@@ -845,6 +846,18 @@ def forward(self, x):
             max_context=cfg.data.max_features,
             non_linearity=torch.nn.ReLU(),
         )
+    elif cfg.net.model_type == "dual_convolutional_network":
+        model = DualConvolutionNetwork(
+            n=cfg.net.n,
+            in_width=cfg.net.in_width,
+            out_width=cfg.net.out_width,
+            embedding_dimension=cfg.net.embedding_dimension,
+            hidden_width=cfg.net.hidden_width,
+            hidden_layers=cfg.net.hidden_layers,
+            in_segments=cfg.net.in_segments,
+            segments=cfg.net.segments,
+            device=cfg.accelerator,
+        )
     elif cfg.net.model_type == "high_order":
         """
         Uniform high order model. All layers are high order.
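
And a minimal sketch of a cfg that would reach the new branch, assuming OmegaConf; build_network here is a hypothetical stand-in for the function that owns this elif chain:

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "accelerator": "cpu",
        "net": {
            "model_type": "dual_convolutional_network",
            "n": 3,
            "in_width": 1,
            "out_width": 128,
            "embedding_dimension": 256,
            "hidden_width": 128,
            "hidden_layers": 2,
            "in_segments": 128,
            "segments": 2,
        },
    }
)
# model = build_network(cfg)  # hypothetical entry point containing the dispatch above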
