Skip to content

Commit

Permalink
properly specify hidden segments
Browse files Browse the repository at this point in the history
  • Loading branch information
jloveric committed Dec 14, 2023
1 parent a7af9a1 commit e0fe41e
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
9 changes: 8 additions & 1 deletion config/net/transformer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,14 @@ inner: 10
#layers: [[1, ${inner}, 128], [${inner}, ${inner}, 10], [${inner}, ${inner}, 10]]
layers:
[
{ "input": 1, "hidden": 8, "output": 8, "layers": 1, "segments": 128 },
{
"input": 1,
"hidden": 8,
"output": 8,
"layers": 1,
"segments": 4,
"input_segments": 128,
},
{ "input": 8, "output": 8, "segments": 4 },
{ "input": 8, "output": 8, "segments": 4 },
{ "input": 8, "output": 8, "segments": 4 },
Expand Down
3 changes: 2 additions & 1 deletion language_interpolation/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ def __init__(
hidden_width = layers[0]["hidden"]
embedding_layers = layers[0]["layers"]
segments = layers[0]["segments"]
input_segments = layers[0]["input_segments"]

mlp_normalization = None
if normalization:
Expand All @@ -355,7 +356,7 @@ def __init__(
layer_type=layer_type,
n=n,
in_width=input_width,
in_segments=segments,
in_segments=input_segments,
out_segments=segments,
hidden_segments=segments,
hidden_layers=embedding_layers,
Expand Down

0 comments on commit e0fe41e

Please sign in to comment.