Skip to content

Commit

Permalink
Fixed set logic implementation for collections.
Browse files Browse the repository at this point in the history
  • Loading branch information
MicahGale committed Nov 28, 2024
1 parent 267a639 commit 4754bde
Showing 1 changed file with 76 additions and 32 deletions.
108 changes: 76 additions & 32 deletions montepy/numbered_object_collection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright 2024, Battelle Energy Alliance, LLC All Rights Reserved.
from abc import ABC, abstractmethod
import itertools as it
import typing
import weakref

Expand Down Expand Up @@ -335,8 +336,14 @@ def _delete_hook(self, obj, **kwargs):

def __internal_append(self, obj, **kwargs):
"""
TODO
The internal append method.
This should always be called rather than manually added.
"""
if not isinstance(obj, self._obj_class):
raise TypeError(
f"Object must be of type: {self._obj_class.__name__}. {obj} given."
)
if obj.number in self.__num_cache:
if obj is self[obj.number]:
return
Expand All @@ -355,15 +362,36 @@ def __internal_delete(self, obj, **kwargs):
self._objects.remove(obj)
self._delete_hook(obj, **kwargs)

def add(self, obj):
# TODO type enforcement
# TODO propagate to Data Numbered
def add(self, obj: Numbered_MCNP_Object):
"""
Add the given object to this collection.
:param obj: The object to add.
:type obj: Numbered_MCNP_Object
:raises TypeError: if the object is of the wrong type.
:raises NumberConflictError: if this object's number is already in use in the collection.
"""
self.__internal_append(obj)

def update(self, objs):
# TODO type enforcement
# TODO propagate to Data Numbered
# not thread safe
"""
Add the given object to this collection.
:param obj: The object to add.
:type obj: Numbered_MCNP_Object
.. note::
This is not a thread-safe method.
:raises TypeError: if the object is of the wrong type.
:raises NumberConflictError: if this object's number is already in use in the collection.
"""
try:
iter(objs)
except TypeError:
raise TypeError(f"Objs must be an iterable. {objs} given.")
for obj in objs:
self.__internal_append(obj)

Expand Down Expand Up @@ -543,25 +571,31 @@ def __contains__(self, other):
return other in self._objects

def __set_logic(self, other, operator):
# TODO type enforcement
# force a num_cache update
"""
Takes another collection, and apply the operator to it, and returns a new instance.
Operator must be a callable that accepts a set of the numbers of self,
and another set for other's numbers.
"""
if not isinstance(other, type(self)):
raise TypeError(
f"Other side must be of the type {type(self).__name__}. {other} given."
)
self_nums = set(self.keys())
other_nums = set(other.keys())
new_nums = operator(self_nums, other_nums)
new_objs = []
# TODO should we verify all the objects are the same?
for obj in self:
new_objs = {}
# give preference to self
for obj in it.chain(other, self):
if obj.number in new_nums:
new_objs.append(obj)
return type(self)(new_objs)
new_objs[obj.number] = obj
return type(self)(list(new_objs.values()))

def __and__(self, other):
"""
Create set-like behavior
"""
return self.__set_logic(other, lambda a, b: a & b)

def __iand__(self, other):
# TODO make examples in doc strings
new_vals = self & other
self.__num_cache.clear()
self._objects.clear()
Expand All @@ -580,7 +614,7 @@ def __sub__(self, other):
return self.__set_logic(other, lambda a, b: a - b)

def __isub__(self, other):
excess_values = self - other
excess_values = self & other
for excess in excess_values:
del self[excess.number]
return self
Expand All @@ -596,7 +630,16 @@ def __ixor__(self, other):
return self

def __set_logic_test(self, other, operator):
# TODO type
"""
Takes another collection, and apply the operator to it, testing the logic of it.
Operator must be a callable that accepts a set of the numbers of self,
and another set for other's numbers.
"""
if not isinstance(other, type(self)):
raise TypeError(
f"Other side must be of the type {type(self).__name__}. {other} given."
)
self_nums = set(self.keys())
other_nums = set(other.keys())
return operator(self_nums, other_nums)
Expand All @@ -622,30 +665,31 @@ def isdisjoint(self, other):
def issuperset(self, other):
return self.__set_logic_test(other, lambda a, b: a.issuperset(b))

def __set_logic_multi(self, others, operator, iterate_all=False):
def __set_logic_multi(self, others, operator):
for other in others:
if not isinstance(other, type(self)):
raise TypeError(
f"Other argument must be of type {type(self).__name__}. {other} given."
)
self_nums = set(self.keys())
other_sets = []
for other in others:
other_sets.append(set(other.keys()))
valid_nums = operator(self_nums, *other_sets)
to_iterate = [self]
if iterate_all:
to_iterate += others
objs = []
for collection in to_iterate:
for obj in collection:
if obj.number in valid_nums:
objs.append(obj)
return type(self)(objs)
objs = {}
for obj in it.chain(*others, self):
if obj.number in valid_nums:
objs[obj.number] = obj
return type(self)(list(objs.values()))

def intersection(self, *others):
return self.__set_logic_multi(others, lambda a, b: a.intersection(b))
return self.__set_logic_multi(others, lambda a, *b: a.intersection(*b))

def union(self, *others):
return self.__set_logic_multi(others, lambda a, b: a.union(b))
return self.__set_logic_multi(others, lambda a, *b: a.union(*b))

def difference(self, *others):
return self.__set_logic_multi(others, lambda a, b: a.difference(b))
return self.__set_logic_multi(others, lambda a, *b: a.difference(*b))

def difference_update(self, *others):
new_vals = self.difference(*others)
Expand Down

0 comments on commit 4754bde

Please sign in to comment.