Skip to content

Commit

Permalink
Merge pull request #136 from g-walley/g-walley/issue133
Browse files Browse the repository at this point in the history
G walley/issue133
  • Loading branch information
g-walley authored Mar 5, 2023
2 parents b69ac26 + 167d1d2 commit 41e71ac
Show file tree
Hide file tree
Showing 8 changed files with 38 additions and 194 deletions.
4 changes: 0 additions & 4 deletions src/cegpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,10 @@
See documentation of more information.
"""
import logging

from cegpy.graphs._ceg import ChainEventGraph
from cegpy.graphs._ceg_reducer import ChainEventGraphReducer
from cegpy.trees._event import EventTree
from cegpy.trees._staged import StagedTree

logging.basicConfig(level=logging.WARN)
logger = logging.getLogger("cegpy")

__version__ = "1.0.5"
30 changes: 9 additions & 21 deletions src/cegpy/graphs/_ceg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from collections import defaultdict
from copy import deepcopy
import itertools as it
import logging
from typing import (
Any,
Dict,
Expand All @@ -23,8 +22,6 @@
from cegpy.utilities._util import generate_filename_and_mkdir
from cegpy.trees._staged import StagedTree

logger = logging.getLogger("cegpy.chain_event_graph")


class CegAlreadyGenerated(Exception):
"""Raised when a CEG is generated twice."""
Expand Down Expand Up @@ -207,14 +204,6 @@ def _generate_dot_graph(self, edge_info="probability"):
if edge_info in self._edge_attributes:
edge_info_dict = nx.get_edge_attributes(self, edge_info)
else:
logger.warning(
"edge_info '%s' does not exist for the %s class. "
"Using the default of 'probability' values "
"on edges instead. For more information, see the "
"documentation.",
edge_info,
self.__class__.__name__,
)
edge_info_dict = nx.get_edge_attributes(self, "probability")

for (src, dst, label), attribute in edge_info_dict.items():
Expand Down Expand Up @@ -277,21 +266,20 @@ def create_figure(
:rtype: IPython.display.Image or None
"""
graph = self.dot_graph(edge_info=edge_info)
if filename is None:
logger.warning("No filename. Figure not saved.")
else:
filename, filetype = generate_filename_and_mkdir(filename)
logger.info("--- generating graph ---")
logger.info("--- writing %s file ---", filetype)
graph.write(str(filename), format=filetype)

if get_ipython() is not None:
logger.info("--- Exporting graph to notebook ---")
graph_image = Image(graph.create_png()) # pylint: disable=no-member
return graph_image
elif filename:
filename, filetype = generate_filename_and_mkdir(filename)
graph.write(str(filename), format=filetype)
else:
graph_image = None
raise RuntimeError(
"Cannot display graph in notebook. "
"Please provide a filename to save the graph to."
)

return graph_image
return None

def _trim_leaves_from_graph(self):
"""Trims all the leaves from the graph, and points each incoming
Expand Down
59 changes: 8 additions & 51 deletions src/cegpy/trees/_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from collections import defaultdict
from typing import Dict, List, Mapping, Optional, Tuple, Union
import logging
import textwrap
import numpy as np
import pydotplus as pdp
Expand All @@ -15,9 +14,6 @@
create_sampling_zeros,
)

# create logger object for this module
logger = logging.getLogger("cegpy.event_tree")


class EventTree(nx.MultiDiGraph):
"""
Expand Down Expand Up @@ -138,7 +134,6 @@ def __init__(
)

self._construct_event_tree()
logger.info("Initialisation complete!")

@property
def root(self) -> str:
Expand All @@ -155,8 +150,6 @@ def variables(self) -> List:
:rtype: List[str]
"""
variables = list(self.dataframe.columns)
logger.info("Variables extracted from dataframe were:")
logger.info(variables)
return variables

@property
Expand All @@ -169,11 +162,6 @@ def sampling_zeros(self) -> Union[List[Tuple[str]], None]:
:rtype: List[Tuple[str]] or None
"""
if self._sampling_zero_paths is None:
logger.info(
"EventTree.sampling_zero_paths \
has not been set."
)
return self._sampling_zero_paths

@sampling_zeros.setter
Expand Down Expand Up @@ -231,23 +219,8 @@ def categories_per_variable(self) -> Dict:
:rtype: Dict[str, Int]
"""

def display_nan_warning():
logger.warning(
textwrap.dedent(
""" --- NaNs found in the dataframe!! ---
cegpy assumes that NaNs are either structural zeros or
structural missing values.
Any non-structural missing values must be dealt with
prior to providing the dataset to any of the cegpy
functions. Any non-structural zeros should be explicitly
added into the cegpy objects.
--- See documentation for more information. ---"""
)
)

categories_to_ignore = {"N/A", "NA", "n/a", "na", "NAN", "nan"}
catagories_per_variable = {}
nans_filtered = False

for var in self.variables:
categories = set(self.dataframe[var].unique().tolist())
Expand All @@ -256,14 +229,9 @@ def display_nan_warning():

# remove any string nans that might have made it in.
filtered_cats = pd_filtered_categories - categories_to_ignore
if pd_filtered_categories != filtered_cats:
nans_filtered = True

catagories_per_variable[var] = len(filtered_cats)

if nans_filtered:
display_nan_warning()

return catagories_per_variable

def dot_graph(self, edge_info: str = "count") -> pdp.Dot:
Expand All @@ -283,14 +251,6 @@ def _generate_dot_graph(self, fill_colour=None, edge_info="count"):
if edge_info in self._edge_attributes:
edge_info_dict = nx.get_edge_attributes(self, edge_info)
else:
logger.warning(
"edge_info '%s' does not exist for the "
"%s class. Using the default of 'count' values "
"on edges instead. For more information, see the "
"documentation.",
edge_info,
self.__class__.__name__,
)
edge_info_dict = nx.get_edge_attributes(self, "count")

for edge, attribute in edge_info_dict.items():
Expand Down Expand Up @@ -330,20 +290,18 @@ def _create_figure(self, graph: pdp.Dot, filename: str):
and saves it to "<filename>.filetype". Supports any filetype that
graphviz supports. e.g: "event_tree.png" or "event_tree.svg" etc.
"""
if filename is None:
logger.warning("No filename. Figure not saved.")
else:
filename, filetype = generate_filename_and_mkdir(filename)
logger.info("--- generating graph ---")
logger.info("--- writing %s file ---", filetype)
graph.write(str(filename), format=filetype)
graph_image = None

if get_ipython() is not None:
logger.info("--- Exporting graph to notebook ---")
graph_image = Image(graph.create_png())
else:
elif filename:
filename, filetype = generate_filename_and_mkdir(filename)
graph.write(str(filename), format=filetype)
graph_image = None
else:
raise RuntimeError(
"Cannot display graph in notebook. "
"Please provide a filename to save the graph to."
)

return graph_image

Expand Down Expand Up @@ -457,7 +415,6 @@ def _construct_event_tree(self):
"""Constructs event_tree DiGraph.
Takes the paths, and adds all the nodes and edges to the Graph"""

logger.info("Starting construction of event tree")
self._create_path_dict_entries()
# Taking a list of a networkx graph object (self) provides a list
# of all the nodes
Expand Down
30 changes: 5 additions & 25 deletions src/cegpy/trees/_staged.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from copy import deepcopy
from fractions import Fraction
from itertools import combinations, chain
import logging
from operator import add, sub, itemgetter
from typing import Dict, List, Optional, Tuple, Union
import networkx as nx
Expand All @@ -12,8 +11,6 @@
from cegpy.trees._event import EventTree
from cegpy.utilities._util import generate_colours

logger = logging.getLogger("cegpy.staged_tree")


# pylint: disable=too-many-instance-attributes
class StagedTree(EventTree):
Expand Down Expand Up @@ -93,7 +90,6 @@ def __init__(
self._sort_count = 0
self._colours_for_situations = []
self._ahc_output = {}
logger.debug("Starting Staged Tree")

@property
def prior(self) -> Dict[Tuple[str], List[Fraction]]:
Expand Down Expand Up @@ -283,20 +279,11 @@ def _store_params(self, prior, alpha, hyperstage) -> None:
if prior:
if alpha:
self.alpha = None
logging.warning(
"Params Warning!! When prior is given, alpha is not required!"
)
self._check_prior(prior)
self.prior = prior
else:
if alpha is None:
self.alpha = self._calculate_default_alpha()
logging.warning(
"Params Warning!! Neither prior nor alpha "
"were provided. Using default alpha "
"value of %d.",
self.alpha,
)
else:
self.alpha = alpha

Expand All @@ -313,7 +300,6 @@ def _calculate_default_alpha(self) -> int:
"""If no alpha is given, a default value is calculated.
The value is calculated by determining the maximum number
of categories that any one variable has"""
logger.info("Calculating default prior")
max_count = max(list(self.categories_per_variable.values()))
return max_count

Expand All @@ -326,7 +312,6 @@ def _create_default_prior(self, alpha) -> list:
edges of a specific situation.
Indexed same as self.situations & self.egde_countset"""

logger.info("Generating default prior")
default_prior = [0] * len(self.situations)
sample_size_at_node = {}

Expand Down Expand Up @@ -372,7 +357,6 @@ def _create_default_hyperstage(self) -> list:
in the hyperstage.
The default is to allow all situations with the same number of
outgoing edges and the same edge labels to be in a common list."""
logger.info("Creating default hyperstage")
hyperstage = []
info_of_edges = []

Expand All @@ -399,7 +383,6 @@ def _create_default_hyperstage(self) -> list:
def _create_edge_countset(self) -> list:
"""Each element of list contains a list with counts along edges emanating from
a specific situation. Indexed same as self.situations"""
logger.info("Creating edge countset")
edge_countset = []

for node in self.situations:
Expand All @@ -422,7 +405,6 @@ def _calculate_sum_of_lg(array) -> float:
def _calculate_initial_loglikelihood(self, prior, posterior) -> float:
"""calculating log likelihood given a prior and posterior"""
# Calculate prior contribution
logger.info("Calculating initial loglikelihood")

pri_lg_of_sum = [self._calculate_lg_of_sum(elem) for elem in prior]
pri_sum_of_lg = [self._calculate_sum_of_lg(elem) for elem in prior]
Expand Down Expand Up @@ -690,7 +672,6 @@ def calculate_AHC_transitions(
:return: The output from the AHC algorithm, specified above.
:rtype: Dict
"""
logger.info("\n\n --- Starting AHC Algorithm ---")

self._store_params(prior, alpha, hyperstage)

Expand Down Expand Up @@ -766,16 +747,15 @@ def create_figure(
graph = self.dot_graph(edge_info)
graph_image = super()._create_figure(graph, filename)

except AttributeError:
logger.error(
except AttributeError as err:
raise RuntimeError(
"----- PLEASE RUN AHC ALGORITHM before trying to"
" export a staged tree graph -----"
)
graph_image = None
) from err

return graph_image
else:
return super().create_figure(filename, edge_info=edge_info)

return super().create_figure(filename, edge_info=edge_info)

def _apply_mean_posterior_probs(
self, merged_situations: List, mean_posterior_probs: List
Expand Down
50 changes: 7 additions & 43 deletions src/tests/test_ceg.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,14 @@ def test_stages_property(self):
def test_create_figure(self):
""".create_figure() called with no filename"""
ceg = ChainEventGraph(self.staged, generate=False)
with self.assertLogs("cegpy", level="INFO") as log_cm:
with self.assertRaisesRegex(
RuntimeError,
(
"Cannot display graph in notebook. "
"Please provide a filename to save the graph to."
),
):
assert ceg.create_figure() is None
self.assertEqual(
["WARNING:cegpy.chain_event_graph:No filename. Figure not saved."],
log_cm.output,
)


class TestCEGHelpersTestCases(unittest.TestCase):
Expand Down Expand Up @@ -556,44 +558,6 @@ def test_gen_nodes_with_increasing_distance(self) -> None:
self.assertEqual(actual_node_list.sort(), nodes.sort())


class TestEdgeInfoAttributes:
"""Test edge_info argument."""

med_s_z_paths: List[Tuple]
med_df: pd.DataFrame
med_st: StagedTree

def setup(self):
"""Test Setup"""
med_df_path = (
Path(__file__)
.resolve()
.parent.parent.joinpath("../data/medical_dm_modified.xlsx")
)
self.med_s_z_paths = None
self.med_df = pd.read_excel(med_df_path)
self.med_st = StagedTree(
dataframe=self.med_df, sampling_zero_paths=self.med_s_z_paths
)

def test_figure_with_wrong_edge_attribute(
self, caplog: pytest.LogCaptureFixture
) -> None:
"""Ensures a warning is raised when a non-existent
attribute is passed for the edge_info argument"""
msg = (
r"edge_info 'prob' does not exist for the "
r"ChainEventGraph class. Using the default of 'probability' values "
r"on edges instead. For more information, see the "
r"documentation."
)

# stratified medical dataset
ceg = ChainEventGraph(self.med_st, generate=False)
_ = ceg.create_figure(filename=None, edge_info="prob")
assert msg in caplog.text, "Expected log message not logged."


@patch.object(ChainEventGraph, "_relabel_nodes")
@patch.object(ChainEventGraph, "_gen_nodes_with_increasing_distance")
@patch.object(ChainEventGraph, "_backwards_construction")
Expand Down
Loading

0 comments on commit 41e71ac

Please sign in to comment.