Skip to content

Commit

Permalink
Building a better model
Browse files Browse the repository at this point in the history
  • Loading branch information
jloveric committed Dec 4, 2023
1 parent afd2309 commit 67e86e0
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 11 deletions.
13 changes: 7 additions & 6 deletions config/net/transformer.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Small network that runs pretty fast. See at
# most 50% accuracy.

layer_type: "discontinuous"
layer_type: "continuous"
normalize: true

# Polynomial interpolation points. Polynomial order
Expand All @@ -15,15 +15,16 @@ segments: 2

# attention blocks
#layers: [[10, 10, 64], [10, 5, 2], [5, 5, 2]]
layers: [[1, 10, 128], [10, 10, 10], [10, 10, 10], [10, 10, 10]]
inner: 10
#layers: [[1, ${inner}, 128], [${inner}, ${inner}, 10], [${inner}, ${inner}, 10]]
layers: [[1, 10, 128], [10, 10, 10], [10, 10, 10]]

output_layer :
output_layer:
segments: 10
hidden_layers: 1
hidden_width: 10

hidden_width: 100

# Note! output dimension is c=heads*output so different than normal
# And then similarity size is c*c*batch_size

heads: 10
heads: 100
10 changes: 5 additions & 5 deletions language_interpolation/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def forward(
kth = kt[:, :, start:end]
vth = vt[:, :, start:end]

qkh = torch.nn.functional.softmax(qth @ kth.transpose(1, 2))
qkh = torch.nn.functional.softmax(qth @ kth.transpose(1, 2), dim=2)
# print("torch number of elements per head", torch.numel(qkh))
# Matrix multiply of last 2 dimensions
qkv_list.append(qkh @ vth)
Expand Down Expand Up @@ -318,7 +318,7 @@ def __init__(
layers: list,
n: int,
output_hidden_layers: int,
output_hidden_width:int,
output_hidden_width: int,
output_segments: int,
normalization: None,
heads: int = 1,
Expand Down Expand Up @@ -356,14 +356,14 @@ def __init__(
layer_type=layer_type,
n=n,
in_width=out_dim,
in_segments = output_segments,
in_segments=output_segments,
out_segments=output_segments,
hidden_segments=output_segments,
hidden_layers=output_hidden_layers,
hidden_width=output_hidden_width,
out_width=128,
device=self._device,
normalization=normalization
normalization=normalization,
)

# Make the positions 0 to max_context-1
Expand Down Expand Up @@ -455,7 +455,7 @@ def select_network(cfg: DictConfig, device: str = None):
max_context=cfg.data.max_features,
output_hidden_layers=cfg.net.output_layer.hidden_layers,
output_hidden_width=cfg.net.output_layer.hidden_width,
output_segments=cfg.net.output_layer.segments
output_segments=cfg.net.output_layer.segments,
)

elif cfg.net.model_type == "high_order":
Expand Down

0 comments on commit 67e86e0

Please sign in to comment.