Skip to content

Commit

Permalink
Fix Model.find_idx signature
Browse files Browse the repository at this point in the history
  • Loading branch information
jinningwang committed Dec 13, 2024
1 parent e9320e7 commit 7c397cb
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 25 deletions.
36 changes: 12 additions & 24 deletions ams/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import warnings

from collections import OrderedDict
from typing import Iterable, Sized
from typing import Iterable

import numpy as np
from andes.core.common import Config
Expand All @@ -16,6 +16,7 @@

from ams.core.documenter import Documenter
from ams.core.var import Algeb
from ams.utils.func import validate_keys_values

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -260,7 +261,7 @@ def alter(self, src, idx, value, attr='v'):
else:
self.set(src, idx, attr=attr, value=value)

def find_idx(self, keys, values, allow_none=False, default=False):
def find_idx(self, keys, values, allow_none=False, default=False, allow_all=False):
"""
Find `idx` of devices whose values match the given pattern.
Expand All @@ -284,40 +285,27 @@ def find_idx(self, keys, values, allow_none=False, default=False):
list
indices of devices
"""
if isinstance(keys, str):
keys = (keys,)
if not isinstance(values, (int, float, str, np.floating)) and not isinstance(values, Iterable):
raise ValueError(f"value must be a string, scalar or an iterable, got {values}")

if len(values) > 0 and not isinstance(values[0], (list, tuple, np.ndarray)):
values = (values,)

elif isinstance(keys, Sized):
if not isinstance(values, Iterable):
raise ValueError(f"value must be an iterable, got {values}")

if len(values) > 0 and not isinstance(values[0], Iterable):
raise ValueError(f"if keys is an iterable, values must be an iterable of iterables. got {values}")

if len(keys) != len(values):
raise ValueError("keys and values must have the same length")
keys, values = validate_keys_values(keys, values)

v_attrs = [self.__dict__[key].v for key in keys]

idxes = []
for v_search in zip(*values):
v_idx = None
v_idx = []
for pos, v_attr in enumerate(zip(*v_attrs)):
if all([i == j for i, j in zip(v_search, v_attr)]):
v_idx = self.idx.v[pos]
break
if v_idx is None:
v_idx.append(self.idx.v[pos])
if not v_idx:
if allow_none is False:
raise IndexError(f'{list(keys)}={v_search} not found in {self.class_name}')
else:
v_idx = default
v_idx = [default]

idxes.append(v_idx)
if allow_all:
idxes.append(v_idx)
else:
idxes.append(v_idx[0])

return idxes

Expand Down
1 change: 0 additions & 1 deletion ams/models/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,6 @@ def find_idx(self, keys, values, allow_none=False, default=None, allow_all=False
indices_found = []
# `indices_found` contains found indices returned from all models of this group
for model in self.models.values():
print(model)
indices_found.append(model.find_idx(keys, values, allow_none=True, default=default, allow_all=True))

# --- find missing pairs ---
Expand Down

0 comments on commit 7c397cb

Please sign in to comment.