diff --git a/montepy/numbered_object_collection.py b/montepy/numbered_object_collection.py index 5de48ce7..2ef6254c 100644 --- a/montepy/numbered_object_collection.py +++ b/montepy/numbered_object_collection.py @@ -1,5 +1,6 @@ # Copyright 2024, Battelle Energy Alliance, LLC All Rights Reserved. from abc import ABC, abstractmethod +import itertools as it import typing import weakref @@ -335,8 +336,14 @@ def _delete_hook(self, obj, **kwargs): def __internal_append(self, obj, **kwargs): """ - TODO + The internal append method. + + This should always be called rather than manually added. """ + if not isinstance(obj, self._obj_class): + raise TypeError( + f"Object must be of type: {self._obj_class.__name__}. {obj} given." + ) if obj.number in self.__num_cache: if obj is self[obj.number]: return @@ -355,15 +362,36 @@ def __internal_delete(self, obj, **kwargs): self._objects.remove(obj) self._delete_hook(obj, **kwargs) - def add(self, obj): - # TODO type enforcement - # TODO propagate to Data Numbered + def add(self, obj: Numbered_MCNP_Object): + """ + Add the given object to this collection. + + :param obj: The object to add. + :type obj: Numbered_MCNP_Object + + :raises TypeError: if the object is of the wrong type. + :raises NumberConflictError: if this object's number is already in use in the collection. + """ self.__internal_append(obj) def update(self, objs): - # TODO type enforcement - # TODO propagate to Data Numbered - # not thread safe + """ + Add the given object to this collection. + + :param obj: The object to add. + :type obj: Numbered_MCNP_Object + + .. note:: + + This is not a thread-safe method. + + :raises TypeError: if the object is of the wrong type. + :raises NumberConflictError: if this object's number is already in use in the collection. + """ + try: + iter(objs) + except TypeError: + raise TypeError(f"Objs must be an iterable. {objs} given.") for obj in objs: self.__internal_append(obj) @@ -543,25 +571,31 @@ def __contains__(self, other): return other in self._objects def __set_logic(self, other, operator): - # TODO type enforcement - # force a num_cache update + """ + Takes another collection, and apply the operator to it, and returns a new instance. + + Operator must be a callable that accepts a set of the numbers of self, + and another set for other's numbers. + """ + if not isinstance(other, type(self)): + raise TypeError( + f"Other side must be of the type {type(self).__name__}. {other} given." + ) self_nums = set(self.keys()) other_nums = set(other.keys()) new_nums = operator(self_nums, other_nums) - new_objs = [] - # TODO should we verify all the objects are the same? - for obj in self: + new_objs = {} + # give preference to self + for obj in it.chain(other, self): if obj.number in new_nums: - new_objs.append(obj) - return type(self)(new_objs) + new_objs[obj.number] = obj + return type(self)(list(new_objs.values())) def __and__(self, other): - """ - Create set-like behavior - """ return self.__set_logic(other, lambda a, b: a & b) def __iand__(self, other): + # TODO make examples in doc strings new_vals = self & other self.__num_cache.clear() self._objects.clear() @@ -580,7 +614,7 @@ def __sub__(self, other): return self.__set_logic(other, lambda a, b: a - b) def __isub__(self, other): - excess_values = self - other + excess_values = self & other for excess in excess_values: del self[excess.number] return self @@ -596,7 +630,16 @@ def __ixor__(self, other): return self def __set_logic_test(self, other, operator): - # TODO type + """ + Takes another collection, and apply the operator to it, testing the logic of it. + + Operator must be a callable that accepts a set of the numbers of self, + and another set for other's numbers. + """ + if not isinstance(other, type(self)): + raise TypeError( + f"Other side must be of the type {type(self).__name__}. {other} given." + ) self_nums = set(self.keys()) other_nums = set(other.keys()) return operator(self_nums, other_nums) @@ -622,30 +665,31 @@ def isdisjoint(self, other): def issuperset(self, other): return self.__set_logic_test(other, lambda a, b: a.issuperset(b)) - def __set_logic_multi(self, others, operator, iterate_all=False): + def __set_logic_multi(self, others, operator): + for other in others: + if not isinstance(other, type(self)): + raise TypeError( + f"Other argument must be of type {type(self).__name__}. {other} given." + ) self_nums = set(self.keys()) other_sets = [] for other in others: other_sets.append(set(other.keys())) valid_nums = operator(self_nums, *other_sets) - to_iterate = [self] - if iterate_all: - to_iterate += others - objs = [] - for collection in to_iterate: - for obj in collection: - if obj.number in valid_nums: - objs.append(obj) - return type(self)(objs) + objs = {} + for obj in it.chain(*others, self): + if obj.number in valid_nums: + objs[obj.number] = obj + return type(self)(list(objs.values())) def intersection(self, *others): - return self.__set_logic_multi(others, lambda a, b: a.intersection(b)) + return self.__set_logic_multi(others, lambda a, *b: a.intersection(*b)) def union(self, *others): - return self.__set_logic_multi(others, lambda a, b: a.union(b)) + return self.__set_logic_multi(others, lambda a, *b: a.union(*b)) def difference(self, *others): - return self.__set_logic_multi(others, lambda a, b: a.difference(b)) + return self.__set_logic_multi(others, lambda a, *b: a.difference(*b)) def difference_update(self, *others): new_vals = self.difference(*others)