Skip to content

Commit

Permalink
Merge branch 'dmc-bugfix' of github.com:arup-group/mc into dmc-bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
syhwawa committed Nov 8, 2023
2 parents 1b36af3 + 441236b commit 8a84e88
Show file tree
Hide file tree
Showing 14 changed files with 1,084 additions and 729 deletions.
2 changes: 1 addition & 1 deletion mc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@
__version__ = "1.0.2"
_ROOT = Path(os.path.abspath(os.path.dirname(__file__)))

_DEFAULTS_DIR = _ROOT / 'default_data'
_DEFAULTS_DIR = _ROOT / "default_data"
if not _DEFAULTS_DIR.is_dir():
raise NotADirectoryError(f"Default data dir not found at {_DEFAULTS_DIR}")
44 changes: 25 additions & 19 deletions mc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,9 +571,9 @@ def __getitem__(self, key):
return self.params[key].value
if key in self.parametersets:
return self.parametersets[key]
if key + ":default" in self.parametersets:
if key + "::default" in self.parametersets:
print("WARNING assuming 'default' required")
return self.parametersets[key + ":default"]
return self.parametersets[key + "::default"]

# try to collect list of paramsets
collected = []
Expand Down Expand Up @@ -654,9 +654,9 @@ def get(self, key, default=None):
return self.params[key].value
if key in self.parametersets:
return self.parametersets[key]
if key + ":default" in self.parametersets:
if key + "::default" in self.parametersets:
print("WARNING assuming 'default' required")
return self.parametersets[key + ":default"]
return self.parametersets[key + "::default"]

return default

Expand Down Expand Up @@ -733,9 +733,9 @@ def __getitem__(self, key):
return self.params[key].value
if key in self.parametersets:
return self.parametersets[key]
if key + ":default" in self.parametersets:
print("WARNING assuming '<parameterset>:default' required")
return self.parametersets[key + ":default"]
if key + "::default" in self.parametersets:
print("WARNING assuming '<parameterset>::default' required")
return self.parametersets[key + "::default"]

# try to collect list of paramsets
collected = []
Expand Down Expand Up @@ -834,9 +834,9 @@ def get(self, key, default=None):
return self.params[key].value
if key in self.parametersets:
return self.parametersets[key]
if key + ":default" in self.parametersets:
if key + "::default" in self.parametersets:
print("WARNING assuming 'default' required")
return self.parametersets[key + ":default"]
return self.parametersets[key + "::default"]

return default

Expand Down Expand Up @@ -886,7 +886,7 @@ def __eq__(self, other):
return True


def specials_snap(a, b, divider=":", ignore="*"):
def specials_snap(a, b, divider="::", ignore="*"):
"""
Special function to check for key matches with consideration of
special character '*' that represents 'all'.
Expand Down Expand Up @@ -938,7 +938,7 @@ def json_path(path: Path) -> bool:
return False


def build_paramset_key(elem: et.Element) -> Tuple[str, str, str]:
def build_paramset_key(elem: et.Element, seperator: str = "::") -> Tuple[str, str, str]:
"""
Function to extract the appropriate suffix from a given parameterset xml element. Returns the
element type (either for subpopulation, mode or activity) and new key. This key is used to
Expand All @@ -962,7 +962,7 @@ def build_paramset_key(elem: et.Element) -> Tuple[str, str, str]:
(uid,) = [
p.attrib["value"] for p in elem.xpath("./param[@name='activityType']")
]
key = paramset_type + ":" + uid
key = paramset_type + seperator + uid
return paramset_type, key, uid

if paramset_type in [
Expand All @@ -972,14 +972,14 @@ def build_paramset_key(elem: et.Element) -> Tuple[str, str, str]:
"modeRangeRestrictionSet",
]:
(uid,) = [p.attrib["value"] for p in elem.xpath("./param[@name='mode']")]
key = paramset_type + ":" + uid
key = paramset_type + seperator + uid
return paramset_type, key, uid

if paramset_type in ["scoringParameters"]:
(uid,) = [
p.attrib["value"] for p in elem.xpath("./param[@name='subpopulation']")
]
key = paramset_type + ":" + uid
key = paramset_type + seperator + uid
return paramset_type, key, uid

if paramset_type in ["strategysettings"]:
Expand All @@ -989,15 +989,21 @@ def build_paramset_key(elem: et.Element) -> Tuple[str, str, str]:
(strategy,) = [
p.attrib["value"] for p in elem.xpath("./param[@name='strategyName']")
]
uid = subpop + ":" + strategy
key = paramset_type + ":" + uid
uid = subpop + seperator + strategy
key = paramset_type + seperator + uid
return paramset_type, key, uid

if paramset_type in ["modeMapping"]:
(uid,) = [
p.attrib["value"] for p in elem.xpath("./param[@name='passengerMode']")
]
key = paramset_type + ":" + uid
key = paramset_type + seperator + uid
return paramset_type, key, uid

if ":" in paramset_type:
"""special cases fpr selector:MultinomialLogit and modeAvailability:Car from DMC mod"""
uid = paramset_type.split(":")[-1]
key = paramset_type + seperator + uid
return paramset_type, key, uid

raise ValueError(
Expand Down Expand Up @@ -1025,13 +1031,13 @@ def sets_diff(self: list, other: list, name: str, loc: str) -> list:
return diffs


def get_paramset_type(key: str) -> str:
def get_paramset_type(key: str, seperator: str = "::") -> str:
"""
Return parameterset type from unique key.
:param key: str
:return: str
"""
return key.split(":")[0]
return key.split(seperator)[0]


def get_params_search(dic: dict, target: str) -> dict:
Expand Down
Loading

0 comments on commit 8a84e88

Please sign in to comment.