Skip to content

Commit

Permalink
Hotfix/recursive get (#43)
Browse files Browse the repository at this point in the history
* Fixing a bug with recursive get_node() and get_connection()

* Adding get_connections to/from functionality

* get_connections_to/from tests are passing

* Tests with new functionality are passing!

* Reformatted with black

* Fixed a syntax error in test_node
  • Loading branch information
fletchapin authored Sep 22, 2022
1 parent 5ab9d67 commit 482086d
Show file tree
Hide file tree
Showing 15 changed files with 224 additions and 32 deletions.
115 changes: 97 additions & 18 deletions wwtp_configuration/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,17 @@ def remove_tag(self, tag_name):
"""
del self.tags[tag_name]

def get_tag(self, tag_name):
def get_tag(self, tag_name, recurse=False):
"""Gets the Tag object associated with `tag_name`
Parameters
----------
tag_name : str
node : Node
`Node` object to be recursively searched for the tag
recurse : bool
Whether or not to get tags recursively.
Default is False, meaning that only tags involving direct children
(and this Node itself) will be returned.
Returns
------
Expand All @@ -93,10 +95,13 @@ def get_tag(self, tag_name):
for connection in self.connections.values():
if tag_name in connection.tags.keys():
tag = connection.tags[tag_name]

if hasattr(self, "nodes") and tag is None:
for node in self.nodes.values():
tag = node.get_tag(tag_name)
if recurse:
tag = node.get_tag(tag_name, recurse=True)
else:
tag = node.tags[tag_name]

if tag:
break

Expand Down Expand Up @@ -158,7 +163,7 @@ def get_connection(self, connection_name, recurse=False):
except KeyError:
if recurse:
for node in self.nodes.values():
result = node.get_connection(connection_name)
result = node.get_connection(connection_name, recurse=True)
if result:
break

Expand Down Expand Up @@ -187,9 +192,7 @@ def get_all_connections(self, recurse=False):
if recurse:
if hasattr(self, "nodes"):
for node in self.nodes.values():
connections = connections + node.get_all_connections(
recurse=recurse
)
connections = connections + node.get_all_connections(recurse=True)
return connections

def get_node(self, node_name, recurse=False):
Expand All @@ -216,7 +219,7 @@ def get_node(self, node_name, recurse=False):
except KeyError:
if recurse:
for node in self.nodes.values():
result = node.get_node(node_name)
result = node.get_node(node_name, recurse=True)
if result:
break

Expand All @@ -243,10 +246,81 @@ def get_all_nodes(self, recurse=False):
nodes = list(self.nodes.values())
if recurse:
for node in self.nodes.values():
nodes = nodes + node.get_all_nodes(recurse=recurse)
nodes = nodes + node.get_all_nodes(recurse=True)

return nodes

def get_all_connections_to(self, node):
"""Gets all connections entering the specified Node, including those
from a different level of the hierarchy with `entry_point` specified.
Paremeters
----------
node : Node
wwtp_configuration `Node` object for which we want to get connections
Returns
-------
list of Connection
List of `Connection` objects entering the specified `node`
"""
if node is None:
return []

connections = self.get_all_connections(recurse=True)
return [
connection
for connection in connections
if connection.destination == node or connection.entry_point == node
]

def get_all_connections_from(self, node):
"""Gets all connections leaving the specified Node, including those
from a different level of the hierarchy with `exit_point` specified.
Paremeters
----------
node : Node
wwtp_configuration `Node` object for which we want to get connections
Returns
-------
list of Connection
List of `Connection` objects leaving the specified `node`
"""
if node is None:
return []

connections = self.get_all_connections(recurse=True)
return [
connection
for connection in connections
if connection.source == node or connection.exit_point == node
]

def get_parent_from_tag(self, tag):
"""Gets the parent object of a `Tag` object, as long as both the tag and its
parent object are children of `self`
Parameters
----------
tag : Tag
wwtp_configuration `Node` object for which we want the parent object
Returns
-------
Node or Connection
parent object of the Tag
"""
# this logic relies on the guarantee from parse_json that
# only tags associated with connections will have a valid destination unit ID
if tag.dest_unit_id:
parent_obj = self.get_connection(tag.parent_id, recurse=True)
else:
parent_obj = self.get_node(tag.parent_id, recurse=True)

return parent_obj


class Network(Node):
"""A water utility represented as a set of connections and nodes
Expand Down Expand Up @@ -379,11 +453,13 @@ def remove_connection(self, connection_name):
"""
del self.connections[connection_name]

def get_cogen_list(self, recurse=False):
def get_list_of_type(self, desired_type, recurse=False):
"""Searches the Facility and returns a list of all Cogeneration objects
Parameters
----------
desired_type : Node or Connection
recurse : bool
Whether or not to get cogenerators recursively.
Default is False, meaning that only direct children will be returned.
Expand All @@ -395,14 +471,17 @@ def get_cogen_list(self, recurse=False):
If `recurse` is True, all children, grandchildren, etc. are returned.
If False, only direct children are returned.
"""
cogen = []
nodes = self.get_all_nodes(recurse=recurse)
desired_objs = []
if issubclass(desired_type, Node):
objs = self.get_all_nodes(recurse=recurse)
else:
objs = self.get_all_connections(recurse=recurse)

for node in nodes:
if isinstance(node, Cogeneration):
cogen.append(node)
for obj in objs:
if isinstance(obj, desired_type):
desired_objs.append(obj)

return cogen
return desired_objs


class Facility(Network):
Expand Down
Binary file modified wwtp_configuration/tests/data/all_connections.pkl
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added wwtp_configuration/tests/data/digester.pkl
Binary file not shown.
Binary file not shown.
Binary file added wwtp_configuration/tests/data/gas_to_cogen.pkl
Binary file not shown.
Binary file not shown.
Binary file not shown.
19 changes: 17 additions & 2 deletions wwtp_configuration/tests/data/node.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"nodes": ["WWTP", "PowerGrid", "RawSewagePump"],
"connections": ["ElectricityToWWTP", "SewerIntake"],
"connections": ["ElectricityToWWTP", "SewerIntake", "GasToGrid"],
"WWTP": {
"type": "Facility",
"elevation (meters)": 1,
Expand Down Expand Up @@ -101,7 +101,9 @@
"type": "Pipe",
"source": "RawSewagePump",
"destination": "WWTP",
"contents": "UntreatedSewage"
"contents": "UntreatedSewage",
"entry_point": "Digester",
"tags": {}
},
"RawSewagePump": {
"type": "Pump",
Expand All @@ -122,5 +124,18 @@
"totalized": true
}
}
},
"GasToGrid": {
"type": "Pipe",
"source": "WWTP",
"destination": "PowerGrid",
"contents": "Biogas",
"exit_point": "Digester",
"flowrate (MGD)": {
"max": null,
"min": null,
"avg": null
},
"tags": {}
}
}
Binary file added wwtp_configuration/tests/data/sewage_pump.pkl
Binary file not shown.
Binary file modified wwtp_configuration/tests/data/top_level_connections.pkl
Binary file not shown.
122 changes: 110 additions & 12 deletions wwtp_configuration/tests/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
import pytest
from collections import Counter
from wwtp_configuration.units import u
from wwtp_configuration.utils import Tag
from wwtp_configuration.parse_json import JSONParser
from wwtp_configuration.node import Cogeneration
from wwtp_configuration.connection import Pipe

os.chdir(os.path.dirname(os.path.abspath(__file__)))

Expand Down Expand Up @@ -32,9 +35,8 @@
)
def test_get_tag(json_path, tag_name, expected_path):
parser = JSONParser(json_path)

result = parser.initialize_network()
tag = result.get_tag(tag_name)
tag = result.get_tag(tag_name, recurse=True)

expected = None
if expected_path:
Expand Down Expand Up @@ -66,7 +68,6 @@ def test_get_tag(json_path, tag_name, expected_path):
)
def test_get_all(json_path, recurse, connection_path, node_path, tag_path):
parser = JSONParser(json_path)

result = parser.initialize_network()

with open(connection_path, "rb") as pickle_file:
Expand Down Expand Up @@ -96,29 +97,126 @@ def test_get_all(json_path, recurse, connection_path, node_path, tag_path):
)
def test_set_energy_efficiency(json_path, cogen_id, efficiency_arg, expected):
parser = JSONParser(json_path)

result = parser.initialize_network()
cogen = result.get_node(cogen_id)

assert cogen.energy_efficiency(efficiency_arg) == expected


@pytest.mark.skipif(skip_all_tests, reason="Exclude all tests")
@pytest.mark.parametrize(
"json_path, recurse, expected",
"json_path, desired_type, recurse, expected",
[
("data/node.json", None, False, "TypeError"),
("data/node.json", Cogeneration, False, []),
("data/node.json", Pipe, False, "data/get_pipe_no_recurse.pkl"),
("data/connection.json", Cogeneration, False, "data/get_cogen.pkl"),
("data/node.json", Cogeneration, True, "data/get_cogen.pkl"),
("data/node.json", Pipe, True, "data/get_pipe_recurse.pkl"),
],
)
def test_get_list_of_type(json_path, desired_type, recurse, expected):
try:
parser = JSONParser(json_path)
result = parser.initialize_network().get_list_of_type(desired_type, recurse)

if isinstance(expected, str) and os.path.isfile(expected):
with open(expected, "rb") as pickle_file:
expected = pickle.load(pickle_file)
except Exception as err:
result = type(err).__name__

assert result == expected


@pytest.mark.skipif(skip_all_tests, reason="Exclude all tests")
@pytest.mark.parametrize(
"json_path, node_id, expected",
[
("data/node.json", False, []),
("data/connection.json", False, "data/get_cogen.pkl"),
("data/node.json", True, "data/get_cogen.pkl"),
# Case 1: node does not exist
("data/node.json", "InvalidNode", []),
# Case 2: no incoming connections but node exists
("data/node.json", "RawSewagePump", []),
# Case 3: only normal connections
("data/node.json", "Cogenerator", "data/connection_to_cogen.pkl"),
# Case 4: normal connections and entry_point
("data/node.json", "Digester", "data/connection_to_digester.pkl"),
],
)
def test_get_cogen_list(json_path, recurse, expected):
def test_get_all_connections_to(json_path, node_id, expected):
parser = JSONParser(json_path)
config = parser.initialize_network()
result = config.get_all_connections_to(config.get_node(node_id, recurse=True))

result = parser.initialize_network()
if isinstance(expected, str) and os.path.isfile(expected):
with open(expected, "rb") as pickle_file:
expected = pickle.load(pickle_file)

assert result == expected


@pytest.mark.skipif(skip_all_tests, reason="Exclude all tests")
@pytest.mark.parametrize(
"json_path, node_id, expected",
[
# Case 1: node does not exist
("data/node.json", "InvalidNode", []),
# Case 2: no outgoing connections but node exists
("data/node.json", "Cogenerator", []),
# Case 3: only normal connections
("data/node.json", "RawSewagePump", "data/connection_from_sewer.pkl"),
# Case 4: normal connections and exit_point
("data/node.json", "Digester", "data/connection_from_digester.pkl"),
],
)
def test_get_all_connections_from(json_path, node_id, expected):
parser = JSONParser(json_path)
config = parser.initialize_network()
result = config.get_all_connections_from(config.get_node(node_id, recurse=True))

if isinstance(expected, str) and os.path.isfile(expected):
with open(expected, "rb") as pickle_file:
expected = pickle.load(pickle_file)

assert result == expected


@pytest.mark.skipif(skip_all_tests, reason="Exclude all tests")
@pytest.mark.parametrize(
"json_path, tag_path, expected",
[
# Case 1: tag does not exist
("data/node.json", "NonexistentTag", None),
# Case 2: tag exists at a top level connection
(
"data/node.json",
"data/top_level_connection_tag.pkl",
"data/electricty_to_wwtp.pkl",
),
# Case 3: tag exists at a lower level connection
(
"data/node.json",
"data/lower_level_connection_tag.pkl",
"data/gas_to_cogen.pkl",
),
# Case 4: tag exists at a top level node
("data/node.json", "data/top_level_node_tag.pkl", "data/sewage_pump.pkl"),
# Case 5: tag exists at a lower level node
("data/node.json", "data/lower_level_node_tag.pkl", "data/digester.pkl"),
],
)
def test_get_parent_from_tag(json_path, tag_path, expected):
if isinstance(tag_path, str) and os.path.isfile(tag_path):
with open(tag_path, "rb") as pickle_file:
tag = pickle.load(pickle_file)
else:
tag = Tag(tag_path, None, None, None, None, None)

parser = JSONParser(json_path)
config = parser.initialize_network()
result = config.get_parent_from_tag(tag)

if isinstance(expected, str) and os.path.isfile(expected):
with open(expected, "rb") as pickle_file:
expected = pickle.load(pickle_file)

assert result.get_cogen_list(recurse) == expected
assert result == expected

0 comments on commit 482086d

Please sign in to comment.