Skip to content

Commit

Permalink
Fixed tests classes to use unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
gareth-walley committed Aug 28, 2024
1 parent cc63212 commit 9a361ea
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 36 deletions.
16 changes: 7 additions & 9 deletions src/tests/test_ceg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,26 @@
# pylint: disable=protected-access
import re
from pathlib import Path
from typing import Dict, List, Mapping, Tuple
from typing import Dict, Mapping
import unittest
from unittest.mock import Mock, patch
import networkx as nx
import pandas as pd
import pytest
import pytest_mock
from cegpy import StagedTree, ChainEventGraph
from cegpy.graphs._ceg import (
CegAlreadyGenerated,
_merge_edge_data,
)


class TestMockedCEGMethods:
class TestMockedCEGMethods(unittest.TestCase):
"""Tests that Mock functions in ChainEventGraph"""

node_prefix = "w"
sink_suffix = "∞"
staged: StagedTree

def setup(self):
def setUp(self):
"""Test setup"""
df_path = (
Path(__file__)
Expand All @@ -34,12 +32,12 @@ def setup(self):
self.staged = StagedTree(dataframe=pd.read_excel(df_path))
self.staged.calculate_AHC_transitions()

def test_generate_argument(self, mocker: pytest_mock.MockerFixture):
@patch("cegpy.graphs._ceg.ChainEventGraph.generate", autospec=True)
def test_generate_argument(self, generate_mock: Mock):
"""When ChainEventGraph called with generate, the .generate()
method is called."""
mocker.patch("cegpy.graphs._ceg.ChainEventGraph.generate")
ceg = ChainEventGraph(self.staged, generate=True)
ceg.generate.assert_called_once() # pylint: disable=no-member
ChainEventGraph(self.staged, generate=True)
generate_mock.assert_called_once() # pylint: disable=no-member


class TestUnitCEG(unittest.TestCase):
Expand Down
4 changes: 2 additions & 2 deletions src/tests/test_ceg_reducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,8 +443,8 @@ def test_str_out(self):
assert str(node) in str_rep


class TestReducedCEGTwo(object):
def setup(self):
class TestReducedCEGTwo(unittest.TestCase):
def setUp(self):
G = nx.MultiDiGraph()
self.init_nodes = ["w0", "w1", "w2", "w3", "w4", "w5", "w6", "w_infinity"]
self.init_edges = [
Expand Down
38 changes: 16 additions & 22 deletions src/tests/test_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def test_incorrect_sampling_zero_fails(self):
)


class TestEventTree:
def setup(self):
class TestEventTree(unittest.TestCase):
def setUp(self):
df_path = (
Path(__file__)
.resolve()
Expand Down Expand Up @@ -148,8 +148,8 @@ def test_node_colours(self) -> None:
assert event_node_colours[0] == "lightgrey"


class TestIntegration:
def setup(self):
class TestIntegration(unittest.TestCase):
def setUp(self):
# stratified dataset
med_df_path = (
Path(__file__)
Expand Down Expand Up @@ -198,8 +198,8 @@ def check_list_contains_strings(str_list) -> bool:
assert isinstance(elem, str)


class TestUsecase:
def setup(self):
class TestUsecase(unittest.TestCase):
def setUp(self):
# stratified dataset
med_df_path = (
Path(__file__)
Expand Down Expand Up @@ -232,7 +232,6 @@ def test_fall_cats_per_var(self):


class TestChangingDataFrame:

@pytest.fixture
def med_df(self):
# stratified dataset
Expand All @@ -247,9 +246,7 @@ def med_df(self):
@pytest.fixture
def med_et(self, med_df):
med_s_z_paths = None
med_et = EventTree(
dataframe=med_df, sampling_zero_paths=med_s_z_paths
)
med_et = EventTree(dataframe=med_df, sampling_zero_paths=med_s_z_paths)
return med_et

@pytest.fixture
Expand All @@ -263,10 +260,7 @@ def fall_df(self):
@pytest.fixture
def fall_et(self, fall_df):
self.fall_s_z_paths = None
return EventTree(
dataframe=fall_df, sampling_zero_paths=self.fall_s_z_paths
)

return EventTree(dataframe=fall_df, sampling_zero_paths=self.fall_s_z_paths)

def test_add_empty_column(self, fall_et, med_df, med_et, fall_df) -> None:
# adding empty column
Expand Down Expand Up @@ -325,8 +319,8 @@ def test_add_same_column_int(self, fall_et, med_df, med_et, fall_df) -> None:
assert len(fall_add_same_et.leaves) == len(fall_et.leaves)


class TestMissingLabels:
def setup(self):
class TestMissingLabels(unittest.TestCase):
def setUp(self):
array = [
np.array(["1", "NotANum", "Recover"]),
np.array(["1", "Trt1", "NotANum"]),
Expand Down Expand Up @@ -538,8 +532,8 @@ def test_complete_case_reduction(self) -> None:
assert df_et.dataframe.equals(expected_df) is True


class TestVariablesFiltered:
def setup(self):
class TestVariablesFiltered(unittest.TestCase):
def setUp(self):
array = [
np.array(["1", "NotANum", "Recover"]),
np.array(["1", "Trt1", "NotANum"]),
Expand All @@ -548,7 +542,7 @@ def setup(self):
np.array(["1", "Trt1", "Recover"]),
np.array(["1", "Trt2", "Recover"]),
np.array(["1", "Trt2", "Dont Recover"]),
np.array(["1", np.NaN, "Dont Recover"]),
np.array(["1", np.nan, "Dont Recover"]),
]

self.df = pd.DataFrame(array)
Expand All @@ -575,8 +569,8 @@ def test_pd_nans_filtered_with_missing(self) -> None:
assert df_et.categories_per_variable == expected_categories


class TestStageColours:
def setup(self):
class TestStageColours(unittest.TestCase):
def setUp(self):
array = [
np.array(["1", "NotANum", "Recover"]),
np.array(["1", "Trt1", "NotANum"]),
Expand All @@ -585,7 +579,7 @@ def setup(self):
np.array(["1", "Trt1", "Recover"]),
np.array(["1", "Trt2", "Recover"]),
np.array(["1", "Trt2", "Dont Recover"]),
np.array(["1", np.NaN, "Dont Recover"]),
np.array(["1", np.nan, "Dont Recover"]),
]

self.df = pd.DataFrame(array)
Expand Down
4 changes: 1 addition & 3 deletions src/tests/test_staged.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@ def med_st(self):
)
med_s_z_paths = None
med_df = pd.read_excel(med_df_path)
med_st = StagedTree(
dataframe=med_df, sampling_zero_paths=med_s_z_paths
)
med_st = StagedTree(dataframe=med_df, sampling_zero_paths=med_s_z_paths)
return med_st

def test_run_ahc_before_figure(self, med_st) -> None:
Expand Down

0 comments on commit 9a361ea

Please sign in to comment.