diff --git a/config/high_order_interpolation.yaml b/config/high_order_interpolation.yaml index 263ee34..9279793 100644 --- a/config/high_order_interpolation.yaml +++ b/config/high_order_interpolation.yaml @@ -67,7 +67,7 @@ prompts: topk: 3 num_predict: 100 defaults: - - optimizer: lion + - optimizer: sparse_lion - net: small #- override hydra/sweeper: nevergrad # hydra: diff --git a/config/optimizer/sparse_lion.yaml b/config/optimizer/sparse_lion.yaml new file mode 100644 index 0000000..fd5864d --- /dev/null +++ b/config/optimizer/sparse_lion.yaml @@ -0,0 +1,7 @@ +name: sparse_lion +# lr supposed to be 3 to 5 times smaller than adam +lr: 1e-5 +patience: 5 +factor: 0.1 +gamma: 0.9 +scheduler: plateau # exponential diff --git a/language_interpolation/networks.py b/language_interpolation/networks.py index 90956a5..20465a3 100644 --- a/language_interpolation/networks.py +++ b/language_interpolation/networks.py @@ -19,6 +19,7 @@ import logging import time from lion_pytorch import Lion +from high_order_layers_torch.sparse_optimizers import SparseLion logger = logging.getLogger(__name__) logging.basicConfig(level=logging.DEBUG) @@ -160,6 +161,10 @@ def configure_optimizers(self): optimizer = Lion( self.parameters(), lr=self.cfg.optimizer.lr, weight_decay=0.0 ) + elif self.cfg.optimizer.name == "sparse_lion": + optimizer = SparseLion( + self.parameters(), lr=self.cfg.optimizer.lr, weight_decay=0.0 + ) elif self.cfg.optimizer.name == "adam": optimizer = optim.Adam( params=self.parameters(), @@ -798,7 +803,7 @@ def select_network(cfg: DictConfig, device: str = None): hidden_segments=cfg.net.hidden.segments, normalization=normalization, device=cfg.accelerator, - layer_type_in=cfg.net.input.get('layer_type', None), + layer_type_in=cfg.net.input.get("layer_type", None), ) elif cfg.net.model_type == "high_order_conv": conv = HighOrderFullyConvolutionalNetwork( diff --git a/pyproject.toml b/pyproject.toml index a8cf7de..ae88fde 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ langchain = "^0.0.334" hydra-core = "^1.3.2" tensorboard = "^2.15.1" lion-pytorch = "^0.1.2" -high-order-layers-torch = "^2.3.3" +high-order-layers-torch = "^2.4.0" [tool.poetry.group.dev.dependencies] black = "^23.11.0"