Skip to content

Commit

Permalink
fix: unify param extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
FlickerSoul committed Nov 2, 2023
1 parent e865ab9 commit 816db6c
Showing 1 changed file with 25 additions and 37 deletions.
62 changes: 25 additions & 37 deletions src/gapper/core/test_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,22 +78,30 @@ def update(self, new_info: Dict[str, Any]) -> None:
class ParamExtractor:
def __init__(self, kwargs: Dict[str, Any]) -> None:
"""Initialize the gap test parameter."""
gap_params = {
k.value: kwargs.pop(k.value)
for k in GapReservedKeywords
if k.value in kwargs
}

if caught_residue := self.check_gap_kwargs_residue(kwargs):
raise ValueError(f"Unknown gap keyword arguments: {caught_residue}")

gap_params = type(self).extra_gap_info(kwargs)
self._param_info = type(self)._select_param_info(gap_params)

@property
def param_info(self) -> ParamInfo:
"""Return the parameter information."""
return self._param_info

@classmethod
def extra_gap_info(
cls, kwargs: Dict[str, Any], check_residue: bool = True
) -> Dict[str, Any]:
gap_kwargs = {
k.value: kwargs.pop(k.value)
for k in GapReservedKeywords
if k.value in kwargs
}

if check_residue:
if caught_residue := cls.check_gap_kwargs_residue(kwargs):
raise ValueError(f"Unknown gap keyword arguments: {caught_residue}")

return gap_kwargs

@staticmethod
def _select_param_info(kwargs: Dict[str, Any]) -> ParamInfo:
"""Select the parameter information to fill in."""
Expand Down Expand Up @@ -256,17 +264,12 @@ def format(self, with_gap_kwargs: bool = False) -> str:
return f"({kwargs_format})"

@classmethod
def bind_override(
def bind(
cls,
*,
gap_override_check: CustomEqualityCheckFn | None = None,
gap_override_test: CustomTestFn | None = None,
**kwargs,
) -> partial[TestParam]:
return partial(
cls,
gap_override_check=gap_override_check,
gap_override_test=gap_override_test,
)
gap_kwargs = ParamExtractor.extra_gap_info(kwargs)
return partial(cls, **gap_kwargs)

def __eq__(self, other: TestParam) -> bool:
if isinstance(other, TestParam):
Expand Down Expand Up @@ -403,14 +406,7 @@ def __init__(
raise warnings.warn("gap_zip is deprecated.")

# pop gap keywords out
gap_kwargs_dict = {
kwd.value: kwargs.pop(kwd.value)
for kwd in GapReservedKeywords
if kwd.value in kwargs
}

if caught_residue := ParamExtractor.check_gap_kwargs_residue(kwargs):
raise ValueError(f"Unknown gap keyword arguments: {caught_residue}")
gap_kwargs_dict = ParamExtractor.extra_gap_info(kwargs)

if gap_params:
self.final_params: List[TestParam] = type(self).parse_params(
Expand Down Expand Up @@ -580,17 +576,9 @@ def __call__(
return prob

@classmethod
def bind_override(
cls,
*,
gap_override_check: CustomEqualityCheckFn | None = None,
gap_override_test: CustomTestFn | None = None,
) -> partial[TestParamBundle]:
return partial(
cls,
gap_override_check=gap_override_check,
gap_override_test=gap_override_test,
)
def bind(cls, **kwargs) -> partial[TestParamBundle]:
gap_kwargs = ParamExtractor.extra_gap_info(kwargs)
return partial(cls, **gap_kwargs)


test_cases = TestParamBundle
Expand Down

0 comments on commit 816db6c

Please sign in to comment.