-
Notifications
You must be signed in to change notification settings - Fork 1
/
config.py
212 lines (174 loc) · 8.39 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
import argparse
from ast import Dict
from typing import Callable, List, Type, Optional, Any, Union
from collections import UserDict
from typing import Sequence
from argparse import ArgumentError
from collections.abc import Mapping, MutableMapping
import os
from distutils.util import strtobool
import copy
import collections.abc
from os import stat
import yaml
import yaml_utils
# Global config instance for this module. It is written by Config.set_inst()
# and read by Config.get_inst(); it stays None until set_inst() is called.
_config: Optional['Config'] = None
def deep_update(d:MutableMapping, u:Mapping, create_map:Callable[[],MutableMapping])\
        ->MutableMapping:
    """Recursively merge mapping ``u`` into mapping ``d`` in place.

    Values of ``u`` that are themselves mappings are merged key by key into
    the corresponding entry of ``d``; every other value simply overwrites the
    entry in ``d``.

    Arguments:
        d -- mapping that is updated in place
        u -- mapping that supplies the new values
        create_map -- zero-argument factory used to create a nested mapping
            in ``d`` when a mapping from ``u`` has nothing mergeable to land on

    Returns:
        ``d`` itself, after the merge.
    """
    for k, v in u.items():
        if isinstance(v, Mapping):
            target = d.get(k)
            # A missing key, or an existing scalar/None value, cannot be merged
            # into: start from a fresh map instead of recursing into a
            # non-mapping (which would fail on item assignment). The factory is
            # only invoked when a new map is really needed.
            if not isinstance(target, MutableMapping):
                target = create_map()
            d[k] = deep_update(target, v, create_map)
        else:
            d[k] = v
    return d
def _save_to_file(d, filepath:Optional[str])->None:
    """Dump ``d`` as block-style YAML to ``filepath`` and report the path on stdout.

    Arguments:
        d -- object to serialize with ``yaml.dump``
        filepath -- destination file, opened for writing (overwritten if present)
    """
    with open(filepath, 'w') as out_stream:
        yaml.dump(d, out_stream, default_flow_style=False)
    print('Config saved to: ', filepath)
def _load_from_file(filepath:Optional[str])->Any:
    """Parse the YAML file at ``filepath`` and return the resulting object.

    Arguments:
        filepath -- path of the YAML file to read

    Returns:
        Whatever ``yaml.load`` produces for the document.
    """
    with open(filepath, 'r') as f:
        # NOTE(security): yaml.Loader can construct arbitrary Python objects.
        # Only use this on trusted config files; consider yaml.SafeLoader if
        # the input may come from an untrusted source.
        config_yaml = yaml.load(f, Loader=yaml.Loader)
    return config_yaml
class Config(UserDict):
    """Hierarchical key/value configuration loaded from YAML files.

    Values may themselves be ``Config`` objects, so the config forms a tree.
    Entries can be overridden by ``--dotted.key value`` pairs coming from the
    ``param_args`` constructor argument and/or from the command line.
    """

    def __init__(self, config_filepath:Optional[str]=None,
                 app_desc:Optional[str]=None, use_args=False,
                 param_args: Sequence = (), resolve_redirects=True) -> None:
        """Create config from specified files and args

        Config is simply a dictionary of key, value map. The value can itself be
        a dictionary so config can be hierarchical. This class allows to load
        config from yaml. A special key '__include__' can specify another yaml
        relative file path (or list of file paths) which will be loaded first
        and the key-value pairs in the main file will override the ones in the
        include file. You can think of the included file as defaults provider.
        This allows to create one base config and then several
        environment/experiment specific configs. On the top of that you can use
        param_args to perform final overrides for a given run.

        Keyword Arguments:
            config_filepath {[str]} -- [Yaml file to load config from, could be names of files separated by semicolon which will be loaded in sequence overriding previous config] (default: {None})
            app_desc {[str]} -- [app description that will show up in --help] (default: {None})
            use_args {bool} -- [if true then command line parameters will override parameters from config files] (default: {False})
            param_args {Sequence} -- [parameters specified as ['--key1',val1,'--key2',val2,...] which will override parameters from config file.] (default: {()})
            resolve_redirects -- [if True then yaml_utils.resolve_all is run over the loaded config]
        """
        super(Config, self).__init__()

        self.args, self.extra_args = None, []

        if use_args:
            # let command line args specify/override config file
            parser = argparse.ArgumentParser(description=app_desc)
            parser.add_argument('--config', type=str, default=None,
                help='config filepath in yaml format, can be list separated by ;')
            self.args, self.extra_args = parser.parse_known_args()
            config_filepath = self.args.config or config_filepath

        if config_filepath:
            for filepath in config_filepath.strip().split(';'):
                self._load_from_file(filepath.strip())

        # Create a copy of ourselves and do the resolution over it.
        # This resolved_conf then can be used to search for overrides that
        # wouldn't have existed before resolution.
        resolved_conf = copy.deepcopy(self)
        if resolve_redirects:
            yaml_utils.resolve_all(resolved_conf)

        # Let's do final overrides from args
        self._update_from_args(param_args, resolved_conf)      # merge from params
        self._update_from_args(self.extra_args, resolved_conf) # merge from command line

        if resolve_redirects:
            yaml_utils.resolve_all(self)

        self.config_filepath = config_filepath

    def _load_from_file(self, filepath:Optional[str])->None:
        """Merge the YAML file at ``filepath`` (and its includes) into this config.

        Does nothing when ``filepath`` is empty or None. Environment variables
        and ``~`` in the path are expanded and the path is made absolute.
        Included files are loaded before the file's own keys are merged, so
        the file's own keys win.
        """
        if filepath:
            filepath = os.path.expanduser(os.path.expandvars(filepath))
            filepath = os.path.abspath(filepath)
            with open(filepath, 'r') as f:
                # NOTE(security): yaml.Loader can construct arbitrary Python
                # objects; only load trusted config files this way.
                config_yaml = yaml.load(f, Loader=yaml.Loader)
            # An empty YAML document parses to None; treat it as an empty map
            # so the include check and the merge below do not fail.
            if config_yaml is None:
                config_yaml = {}
            self._process_includes(config_yaml, filepath)
            deep_update(self, config_yaml, lambda: Config(resolve_redirects=False))
            print('Config loaded from: ', filepath)

    def _process_includes(self, config_yaml, filepath:str):
        """Load every file named by the '__include__' key of ``config_yaml``.

        Include paths are joined onto the directory of ``filepath`` and are
        loaded in the order given.
        """
        if '__include__' in config_yaml:
            # include could be file name or array of file names to apply in sequence
            includes = config_yaml['__include__']
            if isinstance(includes, str):
                includes = [includes]
            assert isinstance(includes, list), "'__include__' value must be string or list"
            base_dir = os.path.dirname(filepath)
            for include in includes:
                self._load_from_file(os.path.join(base_dir, include))

    def _update_from_args(self, args:Sequence, resolved_section:'Config')->None:
        """Apply ``--dotted.key value`` pairs from ``args`` as overrides.

        ``resolved_section`` is the config used to decide whether a key exists
        and which type its value must be converted to. Entries that are not
        ``--`` prefixed strings are skipped one at a time.
        """
        prefix = '--'
        i = 0
        # stop one short of the end: an override always needs a following value
        while i < len(args)-1:
            arg = args[i]
            # values supplied through param_args may be non-strings (ints,
            # bools, ...); they can never be a key, so just step over them
            if isinstance(arg, str) and arg.startswith(prefix):
                path = arg[len(prefix):].split('.')
                i += Config._update_section(self, path, args[i+1], resolved_section)
            else: # some other arg
                i += 1

    def to_dict(self)->dict:
        """Return a copy of this config as nested plain ``dict`` objects."""
        return deep_update({}, self, dict) # type: ignore

    @staticmethod
    def _str_to_bool(val:Any)->bool:
        """Convert a truth-value token to bool.

        Local replacement for ``distutils.util.strtobool`` (distutils was
        removed from the standard library in Python 3.12). Accepts the same
        case-insensitive tokens, and additionally passes real bools through.

        Raises:
            ValueError -- if ``val`` is not a recognized truth value
        """
        if isinstance(val, bool):
            return val
        text = str(val).lower()
        if text in ('y', 'yes', 't', 'true', 'on', '1'):
            return True
        if text in ('n', 'no', 'f', 'false', 'off', '0'):
            return False
        raise ValueError(f'invalid truth value {val!r}')

    @staticmethod
    def _update_section(section:'Config', path:List[str], val:Any, resolved_section:'Config')->int:
        """Set the entry addressed by ``path`` inside ``section`` to ``val``.

        The path must exist in ``resolved_section``; ``val`` is converted to
        the type of the value currently stored there. Intermediate sections
        that exist in ``resolved_section`` but not in ``section`` are created.

        Returns:
            2 if the override was applied (key and value consumed), 1 if the
            path was not found and the argument should be ignored.

        Raises:
            KeyError -- if ``val`` cannot be converted to the existing type
        """
        for sub_path in path[:-1]:
            if sub_path not in resolved_section:
                return 1 # path not found, ignore this
            resolved_section = resolved_section[sub_path]
            # cannot descend through a scalar value; treat as path not found
            if not isinstance(resolved_section, Mapping):
                return 1
            if sub_path not in section:
                section[sub_path] = Config(resolve_redirects=False)
            section = section[sub_path]

        key = path[-1] # final leaf node value

        if key not in resolved_section:
            return 1 # path not found, ignore this

        original_val, original_type = None, None
        try:
            original_val = resolved_section[key]
            original_type = type(original_val)
            if isinstance(original_val, bool): # bool('False') is True :(
                section[key] = Config._str_to_bool(val)
            else:
                section[key] = original_type(val)
        except Exception as e:
            raise KeyError(
                f'The yaml key or command line argument "{key}" is likely not named correctly or value is of wrong data type. Error was occured when setting it to value "{val}".'
                f'Originally it is set to {original_val} which is of type {original_type}.'
                f'Original exception: {e}') from e
        return 2 # path was found, increment arg pointer by 2 as we use up val

    def get_val(self, key, default_val):
        """Return the value stored under ``key``, or ``default_val`` if absent."""
        return super().get(key, default_val)

    @staticmethod
    def set_inst(instance:'Config')->None:
        """Store ``instance`` as the module-wide config instance."""
        global _config
        _config = instance

    @staticmethod
    def get_inst()->'Config':
        """Return the module-wide config instance last stored by ``set_inst``."""
        global _config
        return _config
def get_conf(conf:Optional[Config]=None)->Config:
    """Return ``conf`` if one is given, otherwise the global config instance.

    Arguments:
        conf -- explicit config to use; None means "use Config.get_inst()"
    """
    # Compare against None explicitly: Config is a UserDict, so an empty
    # config is falsy and a plain truthiness test would silently swap it
    # for the global instance.
    if conf is not None:
        return conf
    return Config.get_inst()
def get_conf_common(conf:Optional[Config]=None)->Config:
    """Return the 'common' entry of ``conf`` (or of the config chosen by get_conf)."""
    root = get_conf(conf)
    return root['common']
def get_conf_step(conf:Optional[Config]=None,step:int=1)->Config:
    """Return the 'step<step>' entry of ``conf`` (or of the config chosen by get_conf)."""
    section_name = 'step{}'.format(step)
    root = get_conf(conf)
    return root[section_name]
def get_conf_classifier(conf_step:Optional[Config])->Config:
    """Return the 'classifier' entry of ``conf_step``."""
    section = conf_step['classifier']
    return section
def get_conf_gen_features(conf_step:Optional[Config])->Config:
    """Return the 'generate_features' entry of ``conf_step``."""
    section = conf_step['generate_features']
    return section
def get_conf_ood(conf_step:Optional[Config])->Config:
    """Return the 'ood_detection' entry of ``conf_step``."""
    section = conf_step['ood_detection']
    return section
def get_conf_attr(conf_step:Optional[Config])->Config:
    """Return the 'attribution' entry of ``conf_step``."""
    section = conf_step['attribution']
    return section
def get_conf_cluster(conf_step:Optional[Config])->Config:
    """Return the 'clustering' entry of ``conf_step``."""
    section = conf_step['clustering']
    return section
def get_conf_merge(conf_step:Optional[Config])->Config:
    """Return the 'merge' entry of ``conf_step``."""
    section = conf_step['merge']
    return section
def get_conf_refine(conf_step:Optional[Config])->Config:
    """Return the 'refine' entry of ``conf_step``."""
    section = conf_step['refine']
    return section
def update(d, u):
    """Recursively merge mapping ``u`` into mapping ``d`` in place and return ``d``.

    Mapping values of ``u`` are merged key by key into the matching entry of
    ``d``; any other value overwrites the entry in ``d``. New nested levels
    are created as plain dicts.
    """
    for k, v in u.items():
        if isinstance(v, collections.abc.Mapping):
            existing = d.get(k)
            # A missing key, or an existing scalar/None value, cannot be merged
            # into: start from a fresh dict instead of recursing into a
            # non-mapping (which would fail on item assignment).
            if not isinstance(existing, collections.abc.MutableMapping):
                existing = {}
            d[k] = update(existing, v)
        else:
            d[k] = v
    return d