diff --git a/ams/core/model.py b/ams/core/model.py index 9053c697..371d3dd3 100644 --- a/ams/core/model.py +++ b/ams/core/model.py @@ -6,7 +6,7 @@ import warnings from collections import OrderedDict -from typing import Iterable, Sized +from typing import Iterable import numpy as np from andes.core.common import Config @@ -16,6 +16,7 @@ from ams.core.documenter import Documenter from ams.core.var import Algeb +from ams.utils.func import validate_keys_values logger = logging.getLogger(__name__) @@ -260,7 +261,7 @@ def alter(self, src, idx, value, attr='v'): else: self.set(src, idx, attr=attr, value=value) - def find_idx(self, keys, values, allow_none=False, default=False): + def find_idx(self, keys, values, allow_none=False, default=False, allow_all=False): """ Find `idx` of devices whose values match the given pattern. @@ -284,40 +285,27 @@ def find_idx(self, keys, values, allow_none=False, default=False): list indices of devices """ - if isinstance(keys, str): - keys = (keys,) - if not isinstance(values, (int, float, str, np.floating)) and not isinstance(values, Iterable): - raise ValueError(f"value must be a string, scalar or an iterable, got {values}") - if len(values) > 0 and not isinstance(values[0], (list, tuple, np.ndarray)): - values = (values,) - - elif isinstance(keys, Sized): - if not isinstance(values, Iterable): - raise ValueError(f"value must be an iterable, got {values}") - - if len(values) > 0 and not isinstance(values[0], Iterable): - raise ValueError(f"if keys is an iterable, values must be an iterable of iterables. got {values}") - - if len(keys) != len(values): - raise ValueError("keys and values must have the same length") + keys, values = validate_keys_values(keys, values) v_attrs = [self.__dict__[key].v for key in keys] idxes = [] for v_search in zip(*values): - v_idx = None + v_idx = [] for pos, v_attr in enumerate(zip(*v_attrs)): if all([i == j for i, j in zip(v_search, v_attr)]): - v_idx = self.idx.v[pos] - break - if v_idx is None: + v_idx.append(self.idx.v[pos]) + if not v_idx: if allow_none is False: raise IndexError(f'{list(keys)}={v_search} not found in {self.class_name}') else: - v_idx = default + v_idx = [default] - idxes.append(v_idx) + if allow_all: + idxes.append(v_idx) + else: + idxes.append(v_idx[0]) return idxes diff --git a/ams/models/group.py b/ams/models/group.py index 2a7b5458..f944618c 100644 --- a/ams/models/group.py +++ b/ams/models/group.py @@ -241,7 +241,6 @@ def find_idx(self, keys, values, allow_none=False, default=None, allow_all=False indices_found = [] # `indices_found` contains found indices returned from all models of this group for model in self.models.values(): - print(model) indices_found.append(model.find_idx(keys, values, allow_none=True, default=default, allow_all=True)) # --- find missing pairs ---