Commit ed55d3a

Add interpolation to hydra files

jloveric committed Dec 16, 2023
1 parent 56781bf commit ed55d3a

Showing 4 changed files with 38 additions and 22 deletions.
1 change: 1 addition & 0 deletions config/high_order_interpolation.yaml
@@ -48,6 +48,7 @@ max_epochs: 100
accelerator: cuda
batch_size: 256
gradient_clip: null # 5.0
accumulate_grad_batches: 1

# Are you training? Otherwise plot the result
train: True
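The new `accumulate_grad_batches` key feeds straight into PyTorch Lightning's `Trainer` argument of the same name (see the trainer wiring in `examples/high_order_interpolation.py` below). A minimal sketch of the effect, with illustrative values rather than project defaults:

```python
import pytorch_lightning as pl

# Gradients are summed over 4 consecutive batches before each optimizer
# step, so with batch_size=256 the effective batch size is 1024 while
# per-step GPU memory stays at the 256-sample level.
trainer = pl.Trainer(
    max_epochs=100,
    accelerator="cuda",
    accumulate_grad_batches=4,
)
```

Since the key lives at the config root, it can presumably also be overridden per run, e.g. `python examples/high_order_interpolation.py accumulate_grad_batches=4` (hypothetical value).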
49 changes: 30 additions & 19 deletions config/net/transformer.yaml
@@ -9,30 +9,41 @@ normalize: true
# number of Fourier components.
model_type: high_order_transformer
n: 3
input_segments: 64

segments: 2
# Layers in the form [input, output, segments]
base_width: 64

# attention blocks
#layers: [[10, 10, 64], [10, 5, 2], [5, 5, 2]]
inner: 10
#layers: [[1, ${inner}, 128], [${inner}, ${inner}, 10], [${inner}, ${inner}, 10]]

layers:
  [
    {
      "input": 1,
      "hidden": 16,
      "output": 16,
      "layers": 1,
      "segments": 16,
      "input_segments": 128,
    },
    { "input": 16, "output": 16, "segments": 16 },
    { "input": 16, "output": 16, "segments": 16 },
    { "input": 16, "output": 16, "segments": 16 },
    { "input": 16, "output": 16, "segments": 16 },
    { "input": 16, "hidden": 100, "layers": 1, "segments": 10 },
  ]
  - input: 1
    hidden: 16
    output: 16
    layers: 1
    segments: 16
    input_segments: 128

  - input: 16
    output: ${net.base_width}
    segments: ${net.segments}

  - input: ${net.base_width}
    output: ${net.base_width}
    segments: ${net.segments}

  - input: ${net.base_width}
    output: ${net.base_width}
    segments: ${net.segments}

  - input: ${net.base_width}
    output: ${net.base_width}
    segments: ${net.segments}

  - input: ${net.base_width}
    hidden: 100
    layers: 1
    segments: 10

# Note! The output dimension is c = heads * output, so it differs from the usual case,
# and the similarity size is then c * c * batch_size.
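The `${net.base_width}` and `${net.segments}` entries above are OmegaConf interpolations: each layer stores a reference that is resolved against the composed config when the value is read, so one override rewires every layer. A self-contained sketch of the mechanism (the keys mirror this config, but the snippet is illustrative rather than the project's loading code):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "net": {
            "base_width": 64,
            "segments": 2,
            "layers": [
                {"input": "${net.base_width}", "segments": "${net.segments}"},
            ],
        }
    }
)

print(cfg.net.layers[0].input)  # 64 -- resolved on access
cfg.net.base_width = 128        # a single change propagates
print(cfg.net.layers[0].input)  # 128
```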
7 changes: 4 additions & 3 deletions examples/high_order_interpolation.py
@@ -27,7 +27,9 @@
logging.getLogger().setLevel(logging.DEBUG)


@hydra.main(config_path="../config", config_name="high_order_interpolation", version_base="1.3")
@hydra.main(
    config_path="../config", config_name="high_order_interpolation", version_base="1.3"
)
def run_language_interpolation(cfg: DictConfig):
    logger.info(OmegaConf.to_yaml(cfg))
    logger.info("Working directory : {}".format(os.getcwd()))
@@ -36,7 +38,6 @@ def run_language_interpolation(cfg: DictConfig):
    create_gutenberg_cache(parent_directory=hydra.utils.get_original_cwd())

    if cfg.train is True:

        try:  # Try is needed for multirun case
            if cfg.data.type in dataset_registry:
                dataset_generator = dataset_registry[cfg.data.type]
@@ -46,7 +47,6 @@
                )

                if cfg.net.model_type == "high_order_transformer":

                    # dataset_generator is only one type so using the default
                    datamodule = TransformerDataModule(
                        characters_per_feature=cfg.data.characters_per_feature,
@@ -101,6 +101,7 @@ def run_language_interpolation(cfg: DictConfig):
            max_epochs=cfg.max_epochs,
            accelerator=cfg.accelerator,
            gradient_clip_val=cfg.gradient_clip,
            accumulate_grad_batches=cfg.accumulate_grad_batches,
        )

        model = ASCIIPredictionNet(cfg)
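The `dataset_registry` lookup above is a plain registry pattern: builders are keyed by the `data.type` string, so adding a dataset means registering a callable rather than editing the dispatch. A minimal sketch of the idea (names and builders are illustrative, not the project's actual registry):

```python
from typing import Callable, Dict

# Registry mapping a config string to a dataset-builder callable.
dataset_registry: Dict[str, Callable[..., object]] = {}

def register_dataset(name: str):
    """Decorator that registers a builder under a config key."""
    def wrap(fn: Callable[..., object]) -> Callable[..., object]:
        dataset_registry[name] = fn
        return fn
    return wrap

@register_dataset("gutenberg")
def build_gutenberg(**kwargs):
    ...  # construct and return the dataset

# Dispatch mirrors the config-driven lookup in the example script.
builder = dataset_registry["gutenberg"]
```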
3 changes: 3 additions & 0 deletions language_interpolation/networks.py
@@ -561,11 +561,14 @@ def select_network(cfg: DictConfig, device: str = None):

    if cfg.initialize.type == "linear":
        logger.info("Performing linear initialization")
        start_init = time.perf_counter()
        initialize_network_polynomial_layers(
            model,
            max_slope=cfg.initialize.max_slope,
            max_offset=cfg.initialize.max_offset,
        )
        finish_time = time.perf_counter() - start_init
        logger.info(f"Finished linear initialization {finish_time}")

    return model

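The `time.perf_counter()` pair added here is the standard way to time a code block; the same pattern can be factored into a reusable context manager if other initialization paths need it. A sketch under that assumption (not code from this repository):

```python
import logging
import time
from contextlib import contextmanager

logger = logging.getLogger(__name__)

@contextmanager
def timed(label: str):
    """Log wall-clock seconds spent inside the with-block."""
    start = time.perf_counter()
    try:
        yield
    finally:
        logger.info(f"{label} took {time.perf_counter() - start:.3f}s")

# Usage (mirrors the linear-initialization timing above):
# with timed("linear initialization"):
#     initialize_network_polynomial_layers(model, ...)
```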
