diff --git a/neuralnetlib/activations.py b/neuralnetlib/activations.py index 12f8384..2b5a67f 100644 --- a/neuralnetlib/activations.py +++ b/neuralnetlib/activations.py @@ -12,16 +12,19 @@ def derivative(self, x: np.ndarray) -> np.ndarray: raise NotImplementedError def get_config(self) -> dict: - return {} + return {"name": self.__class__.__name__} @staticmethod def from_config(config: dict): - name = config['name'] + name = config.get('name') + if not name: + raise ValueError('Config must contain "name" field') + + constructor_params = {k: v for k, v in config.items() + if k not in ['name', 'config']} for activation_class in ActivationFunction.__subclasses__(): if activation_class.__name__ == name: - constructor_params = {k: v for k, - v in config.items() if k != 'name'} return activation_class(**constructor_params) raise ValueError(f'Unknown activation function: {name}')