Commit
Simplify generic getting function to loss only
JMGaljaard committed May 11, 2022
1 parent dabcf8e commit 50ab397
Showing 1 changed file with 4 additions and 24 deletions.
28 changes: 4 additions & 24 deletions fltk/util/config/learning_config.py
@@ -167,12 +167,6 @@ def from_yaml(path: Path):
     "CROSSENTROPYLOSS": torch.nn.CrossEntropyLoss,
     "HUBERLOSS" : torch.nn.HuberLoss
 }
-_available_optimizer: Dict[str, Type[torch.optim.Optimizer]] = {
-    "SGD": torch.optim.SGD,
-    "ADAM": torch.optim.Adam,
-    "ADAMW": torch.optim.AdamW
-}
-

 @dataclass_json
 @dataclass
@@ -194,28 +188,14 @@ class DistLearningConfig(LearningConfig):  # pylint: disable=too-many-instance-attributes
     min_lr: float
     seed: int

-    @staticmethod
-    def __safe_get(lookup: Dict[str, T], keyword: str) -> T:
-        """
-        Static function to 'safe' get elements from a dictionary, to prevent issues with Capitalization in the code.
-        @param lookup: Lookup dictionary to 'safe get' from.
-        @type lookup: dict
-        @param keyword: Keyword to 'get' from the Lookup dictionary.
-        @type keyword: str
-        @return: Lookup value from 'safe get' request.
-        @rtype: T
-        """
-        safe_keyword = str.upper(keyword)
-        if safe_keyword not in lookup:
-            logging.fatal(f"Cannot find configuration parameter {keyword} in dictionary.")
-        return lookup.get(safe_keyword)
-
     def get_loss(self) -> Type:
         """
         Function to obtain the loss function Type that was given via commandline to be used during the training
         execution.
         @return: Type corresponding to the loss function that was passed as argument.
         @rtype: Type
         """
-        return self.__safe_get(_available_loss, self.loss)
-
+        safe_keyword = str.upper(self.loss)
+        if safe_keyword not in _available_loss:
+            logging.fatal(f"Cannot find configuration parameter {self.loss} in dictionary.")
+        return _available_loss.get(safe_keyword)
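
For reference, a minimal usage sketch of the lookup that get_loss now inlines: uppercase the configured loss name, log fatally on a miss, and return the matching torch loss class. The resolve_loss helper and the example loss name below are illustrative only and not part of the repository; this assumes a torch version that provides torch.nn.HuberLoss (>= 1.9).

import logging
from typing import Dict, Type

import torch

# Mirror of the module-level lookup table kept by this commit (the optimizer table was removed).
_available_loss: Dict[str, Type] = {
    "CROSSENTROPYLOSS": torch.nn.CrossEntropyLoss,
    "HUBERLOSS": torch.nn.HuberLoss,
}

def resolve_loss(loss_name: str) -> Type:
    """Illustrative stand-in for DistLearningConfig.get_loss after this change."""
    safe_keyword = str.upper(loss_name)  # case-insensitive: "CrossEntropyLoss" -> "CROSSENTROPYLOSS"
    if safe_keyword not in _available_loss:
        logging.fatal(f"Cannot find configuration parameter {loss_name} in dictionary.")
    return _available_loss.get(safe_keyword)  # None when the name is unknown, as in the committed code

loss_cls = resolve_loss("CrossEntropyLoss")  # -> torch.nn.CrossEntropyLoss
criterion = loss_cls()  # instantiate the loss for the training loop

Note that, as in the committed code, an unknown name is only reported via logging.fatal and then resolves to None rather than raising an exception.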
