Skip to content

Commit

Permalink
Fix extrapolation.domain_slice() for mixed extrapolations
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Aug 18, 2024
1 parent c48229e commit 3ddc5f4
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion phiml/math/extrapolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1246,6 +1246,7 @@ def __init__(self, ext_by_boundary: Dict[str, Extrapolation]):
"""
super().__init__(pad_rank=None)
assert all(isinstance(e, Extrapolation) for e in ext_by_boundary.values())
assert all(not isinstance(e, _MixedExtrapolation) for e in ext_by_boundary.values()), f"Nested mixed extrapolations not supported"
assert all(isinstance(k, str) for k in ext_by_boundary.keys())
assert len(set(ext_by_boundary.values())) >= 2, f"Extrapolation can be simplified: {ext_by_boundary}"
self.ext = ext_by_boundary
Expand Down Expand Up @@ -1336,7 +1337,10 @@ def __getitem__(self, item):
return combine_sides({b: ext._getitem_with_domain(item, b[:-1], b.endswith('+'), self._dims) for b, ext in self.ext.items()})

def _getitem_with_domain(self, item: dict, dim: str, upper_edge: bool, all_dims: Sequence[str]) -> 'Extrapolation':
return combine_sides({b: ext._getitem_with_domain(item, b[:-1], b.endswith('+'), all_dims) for b, ext in self.ext.items()})
for b, ext in self.ext.items():
if b in [dim, dim+('+' if upper_edge else '-')]:
return ext._getitem_with_domain(item, b[:-1], b.endswith('+'), all_dims)
raise KeyError((dim, upper_edge))

@property
def _dims(self):
Expand Down

0 comments on commit 3ddc5f4

Please sign in to comment.