diff --git a/config/net/transformer.yaml b/config/net/transformer.yaml
index 5ca5aea..699bfcc 100644
--- a/config/net/transformer.yaml
+++ b/config/net/transformer.yaml
@@ -12,10 +12,18 @@ n: 3
 input_segments: 64
 segments: 2
 
 # Layers in the form [input, output, segments]
+
+# attention blocks
 #layers: [[10, 10, 64], [10, 5, 2], [5, 5, 2]]
 layers: [[1, 10, 128], [10, 10, 10], [10, 10, 10], [10, 10, 10]]
+output_layer :
+  segments: 10
+  hidden_layers: 1
+  hidden_width: 10
+
+
 # Note! output dimension is c=heads*output so different than normal
 # And then similarity size is c*c*batch_size
-heads: 10
+heads: 10
\ No newline at end of file
diff --git a/language_interpolation/networks.py b/language_interpolation/networks.py
index 9ba4d4c..90864cd 100644
--- a/language_interpolation/networks.py
+++ b/language_interpolation/networks.py
@@ -317,6 +317,9 @@ def __init__(
         layer_type: str,
         layers: list,
         n: int,
+        output_hidden_layers: int,
+        output_hidden_width:int,
+        output_segments: int,
         normalization: None,
         heads: int = 1,
         device: str = "cuda",
@@ -349,13 +352,18 @@ def __init__(
             self.layer.append(new_layer)
 
         out_dim = layers[-1][1]
-        self._output_layer = high_order_fc_layers(
+        self._output_layer = HighOrderMLP(
             layer_type=layer_type,
             n=n,
-            segments=segments,
-            in_features=out_dim,
-            out_features=128,
+            in_width=out_dim,
+            in_segments = output_segments,
+            out_segments=output_segments,
+            hidden_segments=output_segments,
+            hidden_layers=output_hidden_layers,
+            hidden_width=output_hidden_width,
+            out_width=128,
             device=self._device,
+            normalization=normalization
         )
 
         # Make the positions 0 to max_context-1
@@ -445,6 +453,9 @@ def select_network(cfg: DictConfig, device: str = None):
             device=cfg.accelerator,
             heads=cfg.net.heads,
             max_context=cfg.data.max_features,
+            output_hidden_layers=cfg.net.output_layer.hidden_layers,
+            output_hidden_width=cfg.net.output_layer.hidden_width,
+            output_segments=cfg.net.output_layer.segments
         )
 
     elif cfg.net.model_type == "high_order":
diff --git a/tests/test_attention_network.py b/tests/test_attention_network.py
index b8143a7..d7d3562 100644
--- a/tests/test_attention_network.py
+++ b/tests/test_attention_network.py
@@ -40,6 +40,9 @@ def test_attention_network():
         device="cpu",
         heads=2,
         max_context=max_features,
+        output_segments=2,
+        output_hidden_layers=1,
+        output_hidden_width=5
     )
     result = network(input_data)
     print("final result", result)