Skip to content

Commit

Permalink
Add sparse lion and use as default
Browse files Browse the repository at this point in the history
  • Loading branch information
jloveric committed Dec 22, 2023
1 parent a17cdcb commit 385e818
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 3 deletions.
2 changes: 1 addition & 1 deletion config/high_order_interpolation.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ prompts:
topk: 3
num_predict: 100
defaults:
- optimizer: lion
- optimizer: sparse_lion
- net: small
#- override hydra/sweeper: nevergrad
# hydra:
Expand Down
7 changes: 7 additions & 0 deletions config/optimizer/sparse_lion.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
name: sparse_lion
# lr supposed to be 3 to 5 times smaller than adam
lr: 1e-5
patience: 5
factor: 0.1
gamma: 0.9
scheduler: plateau # exponential
7 changes: 6 additions & 1 deletion language_interpolation/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import logging
import time
from lion_pytorch import Lion
from high_order_layers_torch.sparse_optimizers import SparseLion

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)
Expand Down Expand Up @@ -160,6 +161,10 @@ def configure_optimizers(self):
optimizer = Lion(
self.parameters(), lr=self.cfg.optimizer.lr, weight_decay=0.0
)
elif self.cfg.optimizer.name == "sparse_lion":
optimizer = SparseLion(
self.parameters(), lr=self.cfg.optimizer.lr, weight_decay=0.0
)
elif self.cfg.optimizer.name == "adam":
optimizer = optim.Adam(
params=self.parameters(),
Expand Down Expand Up @@ -798,7 +803,7 @@ def select_network(cfg: DictConfig, device: str = None):
hidden_segments=cfg.net.hidden.segments,
normalization=normalization,
device=cfg.accelerator,
layer_type_in=cfg.net.input.get('layer_type', None),
layer_type_in=cfg.net.input.get("layer_type", None),
)
elif cfg.net.model_type == "high_order_conv":
conv = HighOrderFullyConvolutionalNetwork(
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ langchain = "^0.0.334"
hydra-core = "^1.3.2"
tensorboard = "^2.15.1"
lion-pytorch = "^0.1.2"
high-order-layers-torch = "^2.3.3"
high-order-layers-torch = "^2.4.0"

[tool.poetry.group.dev.dependencies]
black = "^23.11.0"
Expand Down

0 comments on commit 385e818

Please sign in to comment.