Issue #115 CrossBackendSplitter: add "streamed" split to allow injecting batch job ids on the fly

soxofaan committed Sep 6, 2023
1 parent a81f8fb · commit ff18ad8

Showing 2 changed files with 86 additions and 34 deletions.
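In effect, the new `split_streaming` (added below) yields each secondary subgraph *before* it invokes `get_replacement` for it, so a consumer driving the generator can create a batch job while the generator is suspended and feed the resulting job id back into the primary graph. A minimal sketch of such a driving loop, not part of this commit: the openeo `connection` object, the flat `process_graph` dict, and the toy `backend_for_collection` rule are assumptions, and real orchestration would also start jobs and await results (cf. `run_partitioned_job`).

    job_ids = {}

    def get_replacement(node_id: str, node: dict, sub_graph_id: str) -> dict:
        # By the time this callback runs for a given dependency, the loop below
        # has already received that dependency's subjob and created a job for it.
        return {
            node_id: {
                "process_id": "load_result",
                "arguments": {"id": job_ids[sub_graph_id]},
            }
        }

    splitter = CrossBackendSplitter(backend_for_collection=lambda cid: cid.split("_")[0])
    for sub_id, subjob, dependencies in splitter.split_streaming(
        process_graph=process_graph, get_replacement=get_replacement
    ):
        job = connection.create_job(process_graph=subjob.process_graph)
        job_ids[sub_id] = job.job_id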
118 changes: 85 additions & 33 deletions src/openeo_aggregator/partitionedjobs/crossbackend.py
@@ -5,7 +5,7 @@
 import logging
 import time
 from contextlib import nullcontext
-from typing import Callable, Dict, List, Sequence
+from typing import Callable, Dict, Iterator, List, Optional, Protocol, Sequence, Tuple
 
 import openeo
 from openeo import BatchJob
@@ -20,6 +20,42 @@
 
 _LOAD_RESULT_PLACEHOLDER = "_placeholder:"
 
+# Some type annotation aliases to make things more self-documenting
+SubGraphId = str
+
+
+class GetReplacementCallable(Protocol):
+    """
+    Type annotation for callback functions that produce a node replacement
+    for a node that is split off from the main process graph.
+
+    Also see `_default_get_replacement`.
+    """
+
+    def __call__(self, node_id: str, node: dict, sub_graph_id: SubGraphId) -> dict:
+        """
+        :param node_id: original id of the node in the process graph (e.g. `loadcollection2`)
+        :param node: original node in the process graph (e.g. `{"process_id": "load_collection", "arguments": {...}}`)
+        :param sub_graph_id: id of the corresponding dependency subgraph
+            (to be handled as an opaque id, but possibly something like `backend1:loadcollection2`)
+        :return: new process graph nodes. Should contain at least a node keyed under `node_id`.
+        """
+        ...
+
+
+def _default_get_replacement(node_id: str, node: dict, sub_graph_id: SubGraphId) -> dict:
+    """
+    Default `get_replacement` function to replace a node that has been split off.
+    """
+    return {
+        node_id: {
+            # TODO: use `load_stac` instead of `load_result`
+            "process_id": "load_result",
+            "arguments": {"id": f"{_LOAD_RESULT_PLACEHOLDER}{sub_graph_id}"},
+        }
+    }
+
 
 class CrossBackendSplitter(AbstractJobSplitter):
     """
@@ -42,10 +78,25 @@ def __init__(
         self.backend_for_collection = backend_for_collection
         self._always_split = always_split
 
-    def split(
-        self, process: PGWithMetadata, metadata: dict = None, job_options: dict = None
-    ) -> PartitionedJob:
-        process_graph = process["process_graph"]
+    def split_streaming(
+        self,
+        process_graph: FlatPG,
+        get_replacement: GetReplacementCallable = _default_get_replacement,
+    ) -> Iterator[Tuple[SubGraphId, SubJob, List[SubGraphId]]]:
+        """
+        Split given process graph into sub-process graphs and return these as an iterator,
+        ordered so that a subgraph comes after all subgraphs it depends on
+        (e.g. the main "primary" graph comes last).
+
+        The iterator approach allows working with a dynamic `get_replacement` implementation
+        that adapts to previously produced subgraphs
+        (e.g. creating openEO batch jobs on the fly and injecting the corresponding batch job ids appropriately).
+
+        :return: iterator of tuples containing:
+            - subgraph id
+            - SubJob
+            - dependencies as list of subgraph ids
+        """
 
         # Extract necessary back-ends from `load_collection` usage
         backend_per_collection: Dict[str, str] = {
@@ -57,55 +108,60 @@ def split(
         backend_usage = collections.Counter(backend_per_collection.values())
         _log.info(f"Extracted backend usage from `load_collection` nodes: {backend_usage=} {backend_per_collection=}")
 
         # TODO: more options to determine primary backend?
         primary_backend = backend_usage.most_common(1)[0][0] if backend_usage else None
         secondary_backends = {b for b in backend_usage if b != primary_backend}
         _log.info(f"Backend split: {primary_backend=} {secondary_backends=}")
 
         primary_id = "main"
-        primary_pg = SubJob(process_graph={}, backend_id=primary_backend)
+        primary_pg = {}
         primary_has_load_collection = False
-
-        subjobs: Dict[str, SubJob] = {primary_id: primary_pg}
-        dependencies: Dict[str, List[str]] = {primary_id: []}
+        primary_dependencies = []
 
         for node_id, node in process_graph.items():
             if node["process_id"] == "load_collection":
                 bid = backend_per_collection[node["arguments"]["id"]]
-                if bid == primary_backend and not (
-                    self._always_split and primary_has_load_collection
-                ):
+                if bid == primary_backend and (not self._always_split or not primary_has_load_collection):
                     # Add to primary pg
-                    primary_pg.process_graph[node_id] = node
+                    primary_pg[node_id] = node
                     primary_has_load_collection = True
                 else:
                     # New secondary pg
-                    pg = {
+                    sub_id = f"{bid}:{node_id}"
+                    sub_pg = {
                         node_id: node,
                         "sr1": {
                             # TODO: other/better choices for save_result format (e.g. based on backend support)?
+                            # TODO: particular format options?
                             "process_id": "save_result",
                             "arguments": {
                                 "data": {"from_node": node_id},
-                                # TODO: particular format options?
-                                # "format": "NetCDF",
                                 "format": "GTiff",
                             },
                             "result": True,
                         },
                     }
-                    dependency_id = f"{bid}:{node_id}"
-                    subjobs[dependency_id] = SubJob(process_graph=pg, backend_id=bid)
-                    dependencies[primary_id].append(dependency_id)
-                    # Link to primary pg with load_result
-                    primary_pg.process_graph[node_id] = {
-                        # TODO: encapsulate this placeholder process/id better?
-                        "process_id": "load_result",
-                        "arguments": {
-                            "id": f"{_LOAD_RESULT_PLACEHOLDER}{dependency_id}"
-                        },
-                    }
+
+                    yield (sub_id, SubJob(process_graph=sub_pg, backend_id=bid), [])
+
+                    # Link secondary pg into primary pg
+                    primary_pg.update(get_replacement(node_id=node_id, node=node, sub_graph_id=sub_id))
+                    primary_dependencies.append(sub_id)
             else:
-                primary_pg.process_graph[node_id] = node
+                primary_pg[node_id] = node
 
+        yield (primary_id, SubJob(process_graph=primary_pg, backend_id=primary_backend), primary_dependencies)
+
+    def split(self, process: PGWithMetadata, metadata: dict = None, job_options: dict = None) -> PartitionedJob:
+        """Split given process graph into a `PartitionedJob`"""
+
+        subjobs: Dict[SubGraphId, SubJob] = {}
+        dependencies: Dict[SubGraphId, List[SubGraphId]] = {}
+        for sub_id, subjob, sub_dependencies in self.split_streaming(process_graph=process["process_graph"]):
+            subjobs[sub_id] = subjob
+            if sub_dependencies:
+                dependencies[sub_id] = sub_dependencies
+
         return PartitionedJob(
             process=process,
@@ -116,9 +172,7 @@ def split(
         )
 
 
-def resolve_dependencies(
-    process_graph: FlatPG, batch_jobs: Dict[str, BatchJob]
-) -> FlatPG:
+def _resolve_dependencies(process_graph: FlatPG, batch_jobs: Dict[str, BatchJob]) -> FlatPG:
     """
     Replace placeholders in given process graph
     based on given subjob_id to batch_job_id mapping.
@@ -235,9 +289,7 @@ def run_partitioned_job(
             # Handle job (start, poll status, ...)
             if states[subjob_id] == SUBJOB_STATES.READY:
                 try:
-                    process_graph = resolve_dependencies(
-                        subjob.process_graph, batch_jobs=batch_jobs
-                    )
+                    process_graph = _resolve_dependencies(subjob.process_graph, batch_jobs=batch_jobs)
 
                     _log.info(
                         f"Starting new batch job for subjob {subjob_id!r} on backend {subjob.backend_id!r}"
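For the non-streaming `split` path, `_default_get_replacement` still wires subgraphs together via placeholder ids, which `_resolve_dependencies` later swaps for concrete batch job ids. A small illustration of the placeholder it produces (node and subgraph ids invented for the example):

    replacement = _default_get_replacement(
        node_id="lc2",
        node={"process_id": "load_collection", "arguments": {"id": "S2"}},
        sub_graph_id="b1:lc2",
    )
    # The split-off node is replaced by a load_result on a placeholder id,
    # to be resolved once the dependency's batch job exists.
    assert replacement == {
        "lc2": {
            "process_id": "load_result",
            "arguments": {"id": "_placeholder:b1:lc2"},
        }
    }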
2 changes: 1 addition & 1 deletion tests/partitionedjobs/test_crossbackend.py
@@ -26,7 +26,7 @@ def test_simple(self):
         res = splitter.split({"process_graph": process_graph})
 
         assert res.subjobs == {"main": SubJob(process_graph, backend_id=None)}
-        assert res.dependencies == {"main": []}
+        assert res.dependencies == {}
 
     def test_basic(self):
         process_graph = {
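The updated assertion reflects that `split` now only records non-empty dependency lists (see the `if sub_dependencies:` guard in `split` above), so a graph that stays on a single backend yields `dependencies == {}` rather than `{"main": []}`.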
