Skip to content

Commit

Permalink
add __add__ method to StrainCollection class
Browse files Browse the repository at this point in the history
- add the `__add__` magic method
-  add unit tests for this magic method
  • Loading branch information
CunliangGeng committed Oct 30, 2023
1 parent 0815f48 commit e53ba6e
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/nplinker/strain_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,16 @@ def __eq__(self, other) -> bool:
and self._strain_dict_name == other._strain_dict_name)
return NotImplemented

def __add__(self, other) -> 'StrainCollection':
if isinstance(other, StrainCollection):
sc = StrainCollection()
for strain in self._strains:
sc.add(strain)
for strain in other._strains:
sc.add(strain)
return sc
return NotImplemented

def __contains__(self, item: Strain) -> bool:
"""Check if the strain collection contains the given Strain object.
"""
Expand Down
28 changes: 28 additions & 0 deletions tests/test_strain_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,34 @@ def test_eq(collection: StrainCollection, strain: Strain):
assert collection == other


def test_magic_add(collection: StrainCollection, strain: Strain):
other = StrainCollection()
# same id, same alias
other.add(strain)
# same id, different alias
strain1 = Strain("strain_1")
strain1.add_alias("strain_1_b")
other.add(strain1)
# different id, same alias
strain2 = Strain("strain_2")
strain2.add_alias("strain_2_a")
other.add(strain2)

assert collection + other == other + collection

actual = collection + other
assert len(actual) == 2
assert strain in actual
assert strain1 in actual
assert strain2 in actual
assert len(actual._strain_dict_name) == 5
assert actual._strain_dict_name["strain_1"] == [strain]
assert actual._strain_dict_name["strain_1_a"] == [strain]
assert actual._strain_dict_name["strain_1_b"] == [strain]
assert actual._strain_dict_name["strain_2"] == [strain2]
assert actual._strain_dict_name["strain_2_a"] == [strain2]


def test_contains(collection: StrainCollection, strain: Strain):
assert strain in collection
strain2 = Strain("strain_2")
Expand Down

0 comments on commit e53ba6e

Please sign in to comment.