Issue #150 refactor simple crossbackend split logic for future extension
decouple "load_collection" detection from subjob creation loop
soxofaan committed Sep 17, 2024
1 parent 96200e8 commit 415d58e
Showing 3 changed files with 43 additions and 46 deletions.
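The main change is in split_streaming: "load_collection" detection is now a first pass that only records split points, and the per-backend sub-job process graphs are built afterwards in a separate loop. Below is a condensed, hypothetical sketch of that two-pass flow, not the actual implementation: it omits the _always_split option, the streaming yields, and the get_replacement linking of secondary graphs into the primary one; the helper name two_pass_split and its return shape are made up for illustration, while the other names follow the diff.

from typing import Dict, List, Set, Tuple

NodeId = str
BackendId = str


def two_pass_split(
    process_graph: Dict[NodeId, dict],
    backend_per_collection: Dict[str, BackendId],
    primary_backend: BackendId,
) -> Tuple[Dict[NodeId, dict], List[Tuple[str, Dict[NodeId, dict], BackendId]]]:
    # Pass 1: only detect split points (load_collection nodes that belong
    # on a secondary backend); no sub-job construction happens here yet.
    sub_graphs: List[Tuple[NodeId, Set[NodeId], BackendId]] = []
    for node_id, node in process_graph.items():
        if node["process_id"] == "load_collection":
            bid = backend_per_collection[node["arguments"]["id"]]
            if bid != primary_backend:
                sub_graphs.append((node_id, {node_id}, bid))

    # The primary process graph is whatever no sub-graph claimed.
    # (In the real code, get_replacement additionally rewires primary
    # references to point at the secondary job results.)
    claimed = {n for _, node_ids, _ in sub_graphs for n in node_ids}
    primary_pg = {k: v for k, v in process_graph.items() if k not in claimed}

    # Pass 2: turn each collected sub-graph into a standalone process graph
    # with its own save_result node.
    secondary = []
    for node_id, node_ids, backend_id in sub_graphs:
        sub_pg = {k: v for k, v in process_graph.items() if k in node_ids}
        sub_pg["_agg_crossbackend_save_result"] = {
            "process_id": "save_result",
            "arguments": {"data": {"from_node": node_id}, "format": "GTiff"},
            "result": True,
        }
        secondary.append((f"{backend_id}:{node_id}", sub_pg, backend_id))
    return primary_pg, secondary


# Hypothetical usage, with collection/backend names borrowed from the tests:
pg = {
    "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}},
    "lc2": {"process_id": "load_collection", "arguments": {"id": "B2_FAPAR"}},
    "mc1": {
        "process_id": "merge_cubes",
        "arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}},
        "result": True,
    },
}
primary, secondary = two_pass_split(pg, {"S2": "b1", "B2_FAPAR": "b2"}, primary_backend="b1")

Keeping detection separate from sub-job construction is what the commit title points at: a future split strategy only has to change how sub_graphs is populated, without touching the loop that builds and yields the sub-jobs.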
63 changes: 30 additions & 33 deletions src/openeo_aggregator/partitionedjobs/crossbackend.py
@@ -46,6 +46,8 @@

# Some type annotation aliases to make things more self-documenting
SubGraphId = str
NodeId = str
BackendId = str


class GraphSplitException(Exception):
@@ -140,44 +142,43 @@ def split_streaming(
secondary_backends = {b for b in backend_usage if b != primary_backend}
_log.info(f"Backend split: {primary_backend=} {secondary_backends=}")

primary_id = main_subgraph_id
primary_pg = {}
primary_has_load_collection = False
primary_dependencies = []

sub_graphs: List[Tuple[NodeId, Set[NodeId], BackendId]] = []
for node_id, node in process_graph.items():
if node["process_id"] == "load_collection":
bid = backend_per_collection[node["arguments"]["id"]]
if bid == primary_backend and (not self._always_split or not primary_has_load_collection):
# Add to primary pg
primary_pg[node_id] = node
primary_has_load_collection = True
else:
# New secondary pg
sub_id = f"{bid}:{node_id}"
sub_pg = {
node_id: node,
"sr1": {
# TODO: other/better choices for save_result format (e.g. based on backend support)?
"process_id": "save_result",
"arguments": {
"data": {"from_node": node_id},
# TODO: particular format options?
# "format": "NetCDF",
"format": "GTiff",
},
"result": True,
},
}
sub_graphs.append((node_id, {node_id}, bid))

yield (sub_id, SubJob(process_graph=sub_pg, backend_id=bid), [])
primary_graph_node_ids = set(process_graph.keys()).difference(n for _, ns, _ in sub_graphs for n in ns)
primary_pg = {k: process_graph[k] for k in primary_graph_node_ids}
primary_dependencies = []

# Link secondary pg into primary pg
primary_pg.update(get_replacement(node_id=node_id, node=node, subgraph_id=sub_id))
primary_dependencies.append(sub_id)
else:
primary_pg[node_id] = node
for node_id, subgraph_node_ids, backend_id in sub_graphs:
# New secondary pg
sub_id = f"{backend_id}:{node_id}"
sub_pg = {k: v for k, v in process_graph.items() if k in subgraph_node_ids}
# Add new `save_result` node to the subgraphs
sub_pg["_agg_crossbackend_save_result"] = {
# TODO: other/better choices for save_result format (e.g. based on backend support, cube type)?
"process_id": "save_result",
"arguments": {
"data": {"from_node": node_id},
# TODO: particular format options?
# "format": "NetCDF",
"format": "GTiff",
},
"result": True,
}
yield (sub_id, SubJob(process_graph=sub_pg, backend_id=backend_id), [])

# Link secondary pg into primary pg
primary_pg.update(get_replacement(node_id=node_id, node=process_graph[node_id], subgraph_id=sub_id))
primary_dependencies.append(sub_id)

primary_id = main_subgraph_id
yield (primary_id, SubJob(process_graph=primary_pg, backend_id=primary_backend), primary_dependencies)

def split(self, process: PGWithMetadata, metadata: dict = None, job_options: dict = None) -> PartitionedJob:
@@ -381,11 +382,6 @@ def run_partitioned_job(pjob: PartitionedJob, connection: openeo.Connection, fai
}


# Type aliases to make things more self-documenting
NodeId = str
BackendId = str


@dataclasses.dataclass(frozen=True)
class _FrozenNode:
"""
@@ -427,6 +423,7 @@ class _FrozenGraph:
"""

# TODO: find better class name: e.g. SplitGraphView, GraphSplitUtility, GraphSplitter, ...?
# TODO: add more logging of what is happening under the hood

def __init__(self, graph: dict[NodeId, _FrozenNode]):
# Work with a read-only proxy to prevent accidental changes
4 changes: 2 additions & 2 deletions tests/partitionedjobs/test_api.py
@@ -814,7 +814,7 @@ def test_create_job_basic(self, flask_app, api100, zk_db, dummy1, requests_mock)
"backend_id": "b1",
"process_graph": {
"lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}},
"sr1": {
"_agg_crossbackend_save_result": {
"process_id": "save_result",
"arguments": {"data": {"from_node": "lc2"}, "format": "GTiff"},
"result": True,
@@ -873,7 +873,7 @@ def test_create_job_basic(self, flask_app, api100, zk_db, dummy1, requests_mock)
assert dummy1.get_job_status(TEST_USER, expected_job_id) == "created"
assert dummy1.get_job_data(TEST_USER, expected_job_id).create["process"]["process_graph"] == {
"lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}},
"sr1": {
"_agg_crossbackend_save_result": {
"process_id": "save_result",
"arguments": {"data": {"from_node": "lc2"}, "format": "GTiff"},
"result": True,
22 changes: 11 additions & 11 deletions tests/partitionedjobs/test_crossbackend.py
@@ -50,7 +50,7 @@ def test_split_basic(self):
"cube2": {"from_node": "lc2"},
},
},
"sr1": {
"_agg_crossbackend_save_result": {
"process_id": "save_result",
"arguments": {"data": {"from_node": "mc1"}, "format": "NetCDF"},
"result": True,
@@ -77,7 +77,7 @@ def test_split_basic(self):
"cube2": {"from_node": "lc2"},
},
},
"sr1": {
"_agg_crossbackend_save_result": {
"process_id": "save_result",
"arguments": {"data": {"from_node": "mc1"}, "format": "NetCDF"},
"result": True,
@@ -91,7 +91,7 @@ def test_split_basic(self):
"process_id": "load_collection",
"arguments": {"id": "B2_FAPAR"},
},
"sr1": {
"_agg_crossbackend_save_result": {
"process_id": "save_result",
"arguments": {"data": {"from_node": "lc2"}, "format": "GTiff"},
"result": True,
@@ -113,7 +113,7 @@ def test_split_streaming_basic(self):
"cube2": {"from_node": "lc2"},
},
},
"sr1": {
"_agg_crossbackend_save_result": {
"process_id": "save_result",
"arguments": {"data": {"from_node": "mc1"}, "format": "NetCDF"},
"result": True,
@@ -132,7 +132,7 @@ def test_split_streaming_basic(self):
"process_id": "load_collection",
"arguments": {"id": "B2_FAPAR"},
},
"sr1": {
"_agg_crossbackend_save_result": {
"process_id": "save_result",
"arguments": {"data": {"from_node": "lc2"}, "format": "GTiff"},
"result": True,
@@ -152,7 +152,7 @@ def test_split_streaming_basic(self):
"process_id": "merge_cubes",
"arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}},
},
"sr1": {
"_agg_crossbackend_save_result": {
"process_id": "save_result",
"arguments": {"data": {"from_node": "mc1"}, "format": "NetCDF"},
"result": True,
@@ -204,7 +204,7 @@ def get_replacement(node_id: str, node: dict, subgraph_id: SubGraphId) -> dict:
SubJob(
process_graph={
"lc2": {"process_id": "load_collection", "arguments": {"id": "B2_FAPAR"}},
"sr1": {
"_agg_crossbackend_save_result": {
"process_id": "save_result",
"arguments": {"data": {"from_node": "lc2"}, "format": "GTiff"},
"result": True,
@@ -219,7 +219,7 @@ def get_replacement(node_id: str, node: dict, subgraph_id: SubGraphId) -> dict:
SubJob(
process_graph={
"lc3": {"process_id": "load_collection", "arguments": {"id": "B3_SCL"}},
"sr1": {
"_agg_crossbackend_save_result": {
"process_id": "save_result",
"arguments": {"data": {"from_node": "lc3"}, "format": "GTiff"},
"result": True,
@@ -369,7 +369,7 @@ def test_basic(self, aggregator: _FakeAggregator):
"cube2": {"from_node": "lc2"},
},
},
"sr1": {
"_agg_crossbackend_save_result": {
"process_id": "save_result",
"arguments": {"data": {"from_node": "mc1"}, "format": "NetCDF"},
"result": True,
@@ -404,7 +404,7 @@ def test_basic(self, aggregator: _FakeAggregator):
"cube2": {"from_node": "lc2"},
},
},
"sr1": {
"_agg_crossbackend_save_result": {
"process_id": "save_result",
"arguments": {"data": {"from_node": "mc1"}, "format": "NetCDF"},
"result": True,
@@ -415,7 +415,7 @@ def test_basic(self, aggregator: _FakeAggregator):
"process_id": "load_collection",
"arguments": {"id": "B2_FAPAR"},
},
"sr1": {
"_agg_crossbackend_save_result": {
"process_id": "save_result",
"arguments": {"data": {"from_node": "lc2"}, "format": "GTiff"},
"result": True,