diff --git a/src/openeo_aggregator/partitionedjobs/crossbackend.py b/src/openeo_aggregator/partitionedjobs/crossbackend.py index 36d1c169..324bea85 100644 --- a/src/openeo_aggregator/partitionedjobs/crossbackend.py +++ b/src/openeo_aggregator/partitionedjobs/crossbackend.py @@ -5,7 +5,7 @@ import logging import time from contextlib import nullcontext -from typing import Callable, Dict, List, Sequence +from typing import Callable, Dict, Iterator, List, Optional, Protocol, Sequence, Tuple import openeo from openeo import BatchJob @@ -20,6 +20,42 @@ _LOAD_RESULT_PLACEHOLDER = "_placeholder:" +# Some type annotation aliases to make things more self-documenting +SubGraphId = str + + +class GetReplacementCallable(Protocol): + """ + Type annotation for callback functions that produce a node replacement + for a node that is split off from the main process graph + + Also see `_default_get_replacement` + """ + + def __call__(self, node_id: str, node: dict, sub_graph_id: SubGraphId) -> dict: + """ + :param node_id: original id of the node in the process graph (e.g. `loadcollection2`) + :param node: original node in the process graph (e.g. `{"process_id": "load_collection", "arguments": {...}}`) + :param sub_graph_id: id of the corresponding dependency subgraph + (to be handled as opaque id, but possibly something like `backend1:loadcollection2`) + + :return: new process graph nodes. Should contain at least a node keyed under `node_id` + """ + ... + + +def _default_get_replacement(node_id: str, node: dict, sub_graph_id: SubGraphId) -> dict: + """ + Default `get_replacement` function to replace a node that has been split off. + """ + return { + node_id: { + # TODO: use `load_stac` iso `load_result` + "process_id": "load_result", + "arguments": {"id": f"{_LOAD_RESULT_PLACEHOLDER}{sub_graph_id}"}, + } + } + class CrossBackendSplitter(AbstractJobSplitter): """ @@ -42,10 +78,25 @@ def __init__( self.backend_for_collection = backend_for_collection self._always_split = always_split - def split( - self, process: PGWithMetadata, metadata: dict = None, job_options: dict = None - ) -> PartitionedJob: - process_graph = process["process_graph"] + def split_streaming( + self, + process_graph: FlatPG, + get_replacement: GetReplacementCallable = _default_get_replacement, + ) -> Iterator[Tuple[SubGraphId, SubJob, List[SubGraphId]]]: + """ + Split given process graph in sub-process graphs and return these as an iterator + in an order so that a subgraph comes after all subgraphs it depends on + (e.g. main "primary" graph comes last). + + The iterator approach allows working with a dynamic `get_replacement` implementation + that adapting to on previously produced subgraphs + (e.g. creating openEO batch jobs on the fly and injecting the corresponding batch job ids appropriately). + + :return: tuple containing: + - subgraph id + - SubJob + - dependencies as list of subgraph ids + """ # Extract necessary back-ends from `load_collection` usage backend_per_collection: Dict[str, str] = { @@ -57,55 +108,60 @@ def split( backend_usage = collections.Counter(backend_per_collection.values()) _log.info(f"Extracted backend usage from `load_collection` nodes: {backend_usage=} {backend_per_collection=}") + # TODO: more options to determine primary backend? primary_backend = backend_usage.most_common(1)[0][0] if backend_usage else None secondary_backends = {b for b in backend_usage if b != primary_backend} _log.info(f"Backend split: {primary_backend=} {secondary_backends=}") primary_id = "main" - primary_pg = SubJob(process_graph={}, backend_id=primary_backend) + primary_pg = {} primary_has_load_collection = False - - subjobs: Dict[str, SubJob] = {primary_id: primary_pg} - dependencies: Dict[str, List[str]] = {primary_id: []} + primary_dependencies = [] for node_id, node in process_graph.items(): if node["process_id"] == "load_collection": bid = backend_per_collection[node["arguments"]["id"]] - if bid == primary_backend and not ( - self._always_split and primary_has_load_collection - ): + if bid == primary_backend and (not self._always_split or not primary_has_load_collection): # Add to primary pg - primary_pg.process_graph[node_id] = node + primary_pg[node_id] = node primary_has_load_collection = True else: # New secondary pg - pg = { + sub_id = f"{bid}:{node_id}" + sub_pg = { node_id: node, "sr1": { # TODO: other/better choices for save_result format (e.g. based on backend support)? - # TODO: particular format options? "process_id": "save_result", "arguments": { "data": {"from_node": node_id}, + # TODO: particular format options? # "format": "NetCDF", "format": "GTiff", }, "result": True, }, } - dependency_id = f"{bid}:{node_id}" - subjobs[dependency_id] = SubJob(process_graph=pg, backend_id=bid) - dependencies[primary_id].append(dependency_id) - # Link to primary pg with load_result - primary_pg.process_graph[node_id] = { - # TODO: encapsulate this placeholder process/id better? - "process_id": "load_result", - "arguments": { - "id": f"{_LOAD_RESULT_PLACEHOLDER}{dependency_id}" - }, - } + + yield (sub_id, SubJob(process_graph=sub_pg, backend_id=bid), []) + + # Link secondary pg into primary pg + primary_pg.update(get_replacement(node_id=node_id, node=node, sub_graph_id=sub_id)) + primary_dependencies.append(sub_id) else: - primary_pg.process_graph[node_id] = node + primary_pg[node_id] = node + + yield (primary_id, SubJob(process_graph=primary_pg, backend_id=primary_backend), primary_dependencies) + + def split(self, process: PGWithMetadata, metadata: dict = None, job_options: dict = None) -> PartitionedJob: + """Split given process graph into a `PartitionedJob`""" + + subjobs: Dict[SubGraphId, SubJob] = {} + dependencies: Dict[SubGraphId, List[SubGraphId]] = {} + for sub_id, subjob, sub_dependencies in self.split_streaming(process_graph=process["process_graph"]): + subjobs[sub_id] = subjob + if sub_dependencies: + dependencies[sub_id] = sub_dependencies return PartitionedJob( process=process, @@ -116,9 +172,7 @@ def split( ) -def resolve_dependencies( - process_graph: FlatPG, batch_jobs: Dict[str, BatchJob] -) -> FlatPG: +def _resolve_dependencies(process_graph: FlatPG, batch_jobs: Dict[str, BatchJob]) -> FlatPG: """ Replace placeholders in given process graph based on given subjob_id to batch_job_id mapping. @@ -235,9 +289,7 @@ def run_partitioned_job( # Handle job (start, poll status, ...) if states[subjob_id] == SUBJOB_STATES.READY: try: - process_graph = resolve_dependencies( - subjob.process_graph, batch_jobs=batch_jobs - ) + process_graph = _resolve_dependencies(subjob.process_graph, batch_jobs=batch_jobs) _log.info( f"Starting new batch job for subjob {subjob_id!r} on backend {subjob.backend_id!r}" diff --git a/tests/partitionedjobs/test_crossbackend.py b/tests/partitionedjobs/test_crossbackend.py index 8d1e2c82..50263236 100644 --- a/tests/partitionedjobs/test_crossbackend.py +++ b/tests/partitionedjobs/test_crossbackend.py @@ -26,7 +26,7 @@ def test_simple(self): res = splitter.split({"process_graph": process_graph}) assert res.subjobs == {"main": SubJob(process_graph, backend_id=None)} - assert res.dependencies == {"main": []} + assert res.dependencies == {} def test_basic(self): process_graph = {