Skip to content

Commit

Permalink
Issue #150 improve DeepGraphSplitter test coverage
Browse files Browse the repository at this point in the history
also make graph walking more deterministic (e.g. to simplify test asserts)
  • Loading branch information
soxofaan committed Sep 18, 2024
1 parent a1a9b64 commit 2cc7c86
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 72 deletions.
26 changes: 19 additions & 7 deletions src/openeo_aggregator/partitionedjobs/crossbackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,7 @@ class _FrozenNode:

# TODO: instead of frozen dataclass: have __init__ with some type casting/validation. Or use attrs?
# TODO: better name for this class?
# TODO: use NamedTuple instead of dataclass?

# Node ids of other nodes this node depends on (aka parents)
depends_on: frozenset[NodeId]
Expand Down Expand Up @@ -560,26 +561,36 @@ def iter_nodes(self) -> Iterator[Tuple[NodeId, _FrozenNode]]:
yield from self._graph.items()

def _walk(
self, seeds: Iterable[NodeId], next_nodes: Callable[[NodeId], Iterable[NodeId]], include_seeds: bool = True
self,
seeds: Iterable[NodeId],
next_nodes: Callable[[NodeId], Iterable[NodeId]],
include_seeds: bool = True,
auto_sort: bool = True,
) -> Iterator[NodeId]:
"""
Walk the graph nodes starting from given seed nodes, taking steps as defined by `next_nodes` function.
Optionally include seeds or not, and walk breadth first.
"""
if auto_sort:
# Automatically sort next nodes to make walk more deterministic
prepare = sorted
else:
prepare = lambda x: x

if include_seeds:
visited = set()
to_visit = list(seeds)
to_visit = list(prepare(seeds))
else:
visited = set(seeds)
to_visit = [n for s in seeds for n in next_nodes(s)]
to_visit = [n for s in seeds for n in prepare(next_nodes(s))]

while to_visit:
node_id = to_visit.pop(0)
if node_id in visited:
continue
yield node_id
visited.add(node_id)
to_visit.extend(set(next_nodes(node_id)).difference(visited))
to_visit.extend(prepare(set(next_nodes(node_id)).difference(visited)))

def walk_upstream_nodes(self, seeds: Iterable[NodeId], include_seeds: bool = True) -> Iterator[NodeId]:
"""
Expand Down Expand Up @@ -728,7 +739,7 @@ def produce_split_locations(self, limit: int = 2) -> Iterator[List[NodeId]]:
# Sort forsaken nodes (based on forsaken parent count), to start higher up the graph
# TODO: avoid need for this sort, and just use a better scoring metric higher up?
forsaken_nodes = sorted(
forsaken_nodes, reverse=True, key=lambda n: sum(p in forsaken_nodes for p in self.node(n).depends_on)
forsaken_nodes, key=lambda n: sum(p in forsaken_nodes for p in self.node(n).depends_on)
)
# Collect nodes where we could split the graph in disjoint subgraphs
articulation_points: Set[NodeId] = set(self.find_articulation_points())
Expand All @@ -745,16 +756,17 @@ def produce_split_locations(self, limit: int = 2) -> Iterator[List[NodeId]]:
raise GraphSplitException("No split options found.")
# TODO: how to handle limit? will it scale feasibly to iterate over all possibilities at this point?
# TODO: smarter picking of split node (e.g. one with most upstream nodes)
assert limit > 0
for split_node_id in split_options[:limit]:
# Split graph at this articulation point
up, down = self.split_at(split_node_id)
if down.find_forsaken_nodes():
down_splits = list(down.produce_split_locations(limit=limit - 1))
down_splits = list(down.produce_split_locations(limit=max(limit - 1, 1)))
else:
down_splits = [[]]
if up.find_forsaken_nodes():
# TODO: will this actually happen? the upstream sub-graph should be single-backend by design?
up_splits = list(up.produce_split_locations(limit=limit - 1))
up_splits = list(up.produce_split_locations(limit=max(limit - 1, 1)))
else:
up_splits = [[]]

Expand Down
Loading

0 comments on commit 2cc7c86

Please sign in to comment.