From 50ab39780c0e32e51714f2213f5ed763907979c3 Mon Sep 17 00:00:00 2001 From: Jeroen Galjaard Date: Wed, 11 May 2022 11:44:56 +0200 Subject: [PATCH] Simply generic getting function to loss only --- fltk/util/config/learning_config.py | 28 ++++------------------------ 1 file changed, 4 insertions(+), 24 deletions(-) diff --git a/fltk/util/config/learning_config.py b/fltk/util/config/learning_config.py index 7d644333..5cfaf15b 100644 --- a/fltk/util/config/learning_config.py +++ b/fltk/util/config/learning_config.py @@ -167,12 +167,6 @@ def from_yaml(path: Path): "CROSSENTROPYLOSS": torch.nn.CrossEntropyLoss, "HUBERLOSS" : torch.nn.HuberLoss } -_available_optimizer: Dict[str, Type[torch.optim.Optimizer]] = { - "SGD": torch.optim.SGD, - "ADAM": torch.optim.Adam, - "ADAMW": torch.optim.AdamW -} - @dataclass_json @dataclass @@ -194,22 +188,6 @@ class DistLearningConfig(LearningConfig): # pylint: disable=too-many-instance-a min_lr: float seed: int - @staticmethod - def __safe_get(lookup: Dict[str, T], keyword: str) -> T: - """ - Static function to 'safe' get elements from a dictionary, to prevent issues with Capitalization in the code. - @param lookup: Lookup dictionary to 'safe get' from. - @type lookup: dict - @param keyword: Keyword to 'get' from the Lookup dictionary. - @type keyword: str - @return: Lookup value from 'safe get' request. - @rtype: T - """ - safe_keyword = str.upper(keyword) - if safe_keyword not in lookup: - logging.fatal(f"Cannot find configuration parameter {keyword} in dictionary.") - return lookup.get(safe_keyword) - def get_loss(self) -> Type: """ Function to obtain the loss function Type that was given via commandline to be used during the training @@ -217,5 +195,7 @@ def get_loss(self) -> Type: @return: Type corresponding to the loss function that was passed as argument. @rtype: Type """ - return self.__safe_get(_available_loss, self.loss) - + safe_keyword = str.upper(self.loss) + if safe_keyword not in _available_loss: + logging.fatal(f"Cannot find configuration parameter {self.loss} in dictionary.") + return _available_loss.get(safe_keyword)