From dd0c8c792fea86b953b538b6d15cacc161fe4bd8 Mon Sep 17 00:00:00 2001 From: Francis Charette Migneault Date: Thu, 22 Aug 2024 20:38:53 -0400 Subject: [PATCH] [wip] apply resolution of collection files for complex process input --- requirements.txt | 1 + weaver/datatype.py | 36 ++++++- weaver/formats.py | 11 ++ .../processes/builtin/collection_processor.py | 101 ++++++++++++------ weaver/processes/execution.py | 20 +++- weaver/processes/utils.py | 8 +- weaver/typedefs.py | 2 +- weaver/utils.py | 13 ++- 8 files changed, 151 insertions(+), 41 deletions(-) diff --git a/requirements.txt b/requirements.txt index 50cb6f179..a203d89a4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -89,6 +89,7 @@ pyramid_celery @ git+https://github.com/crim-ca/pyramid_celery.git@5.0.0a pyramid_mako pyramid_rewrite pyramid_storage +pystac pystac_client python-box python-dateutil diff --git a/weaver/datatype.py b/weaver/datatype.py index 067e24c9a..306f88a86 100644 --- a/weaver/datatype.py +++ b/weaver/datatype.py @@ -7,7 +7,10 @@ import enum import inspect import json +import os import re +import shutil +import tempfile import traceback import uuid import warnings @@ -44,6 +47,7 @@ from weaver.store.base import StoreProcesses from weaver.utils import localize_datetime # for backward compatibility of previously saved jobs not time-locale-aware from weaver.utils import ( + LoggerHandler, VersionFormat, apply_number_with_unit, as_version_major_minor_patch, @@ -90,6 +94,7 @@ Link, Metadata, Number, + Path, Price, QuoteProcessParameters, QuoteProcessResults, @@ -616,7 +621,7 @@ def check_accessible(self, settings, ignore=True): return False -class Job(Base): +class Job(Base, LoggerHandler): """ Dictionary that contains :term:`Job` details for local :term:`Process` or remote :term:`OWS` execution. @@ -630,6 +635,24 @@ def __init__(self, *args, **kwargs): raise TypeError(f"Parameter 'task_id' is required for '{self.__name__}' creation.") if not isinstance(self.id, (str, uuid.UUID)): raise TypeError(f"Type 'str' or 'UUID' is required for '{self.__name__}.id'") + self["__tmpdir"] = None + + def cleanup(self): + if self["__tmpdir"] and os.path.isdir(self["__tmpdir"]): + shutil.rmtree(self["__tmpdir"], ignore_errors=True) + + @property + def tmpdir(self): + # type: () -> Path + """ + Optional temporary directory available for the :term:`Job` to store files needed for its operation. + + It is up to the caller to remove the contents by calling :meth:`cleanup`. + """ + _tmpdir = self.get("__tmpdir") + if not _tmpdir: + _tmpdir = self["__tmpdir"] = tempfile.mkdtemp() + return _tmpdir @staticmethod def _get_message(message, size_limit=None): @@ -654,7 +677,18 @@ def _get_err_msg(error, size_limit=None): error_msg = Job._get_message(error.text, size_limit=size_limit) return f"{error_msg} - code={error.code} - locator={error.locator}" + def log(self, level, message, *args, **kwargs): + # type: (AnyLogLevel, str, *str, **Any) -> None + """ + Provides the :class:`LoggerHandler` interface, allowing to pass the :term:`Job` directly as a logger reference. + + The same parameters as :meth:`save_log` can be provided. + """ + message = message.format(*args, **kwargs) + return self.save_log(level=level, message=message, **kwargs) + def save_log(self, + *, errors=None, # type: Optional[Union[str, Exception, WPSException, List[WPSException]]] logger=None, # type: Optional[Logger] message=None, # type: Optional[str] diff --git a/weaver/formats.py b/weaver/formats.py index 3b77cac82..3b682d7e7 100644 --- a/weaver/formats.py +++ b/weaver/formats.py @@ -1,5 +1,6 @@ import base64 import datetime +import functools import json import logging import os @@ -105,6 +106,7 @@ class ContentType(Constants): APP_ZIP = "application/zip" IMAGE_GEOTIFF = "image/tiff; subtype=geotiff" IMAGE_OGC_GEOTIFF = "image/tiff; application=geotiff" + IMAGE_COG = "image/tiff; application=geotiff; profile=cloud-optimized" IMAGE_JPEG = "image/jpeg" IMAGE_GIF = "image/gif" IMAGE_PNG = "image/png" @@ -603,6 +605,7 @@ class SchemaRole(Constants): OGC_MAPPING = { ContentType.IMAGE_GEOTIFF: "geotiff", ContentType.IMAGE_OGC_GEOTIFF: "geotiff", + ContentType.IMAGE_COG: "geotiff", ContentType.APP_NETCDF: "netcdf", } FORMAT_NAMESPACE_MAPPINGS = { @@ -623,6 +626,7 @@ class SchemaRole(Constants): FORMAT_NAMESPACES = frozenset(FORMAT_NAMESPACE_DEFINITIONS) +@functools.cache def get_allowed_extensions(): # type: () -> List[str] """ @@ -649,6 +653,7 @@ def get_allowed_extensions(): return list(base | extra) +@functools.cache def get_format(media_type, default=None): # type: (str, Optional[str]) -> Optional[Format] """ @@ -668,6 +673,7 @@ def get_format(media_type, default=None): return fmt +@functools.cache def get_extension(media_type, dot=True): # type: (str, bool) -> str """ @@ -697,6 +703,7 @@ def _handle_dot(_ext): return _handle_dot(ext) +@functools.cache def get_content_type(extension, charset=None, default=None): # type: (str, Optional[str], Optional[str]) -> Optional[str] """ @@ -721,6 +728,7 @@ def get_content_type(extension, charset=None, default=None): return add_content_type_charset(ctype, charset) +@functools.cache def add_content_type_charset(content_type, charset): # type: (Union[str, ContentType], Optional[str]) -> str """ @@ -739,6 +747,7 @@ def add_content_type_charset(content_type, charset): return content_type +@functools.cache def get_cwl_file_format(media_type, make_reference=False, must_exist=True, allow_synonym=True): # pylint: disable=R1260 # type: (str, bool, bool, bool) -> Union[Tuple[Optional[JSON], Optional[str]], Optional[str]] """ @@ -860,6 +869,7 @@ def _request_extra_various(_media_type): return None if make_reference else (None, None) +@functools.cache def map_cwl_media_type(cwl_format): # type: (Optional[str]) -> Optional[str] """ @@ -891,6 +901,7 @@ def map_cwl_media_type(cwl_format): return ctype +@functools.cache def clean_media_type_format(media_type, suffix_subtype=False, strip_parameters=False): # type: (str, bool, bool) -> Optional[str] """ diff --git a/weaver/processes/builtin/collection_processor.py b/weaver/processes/builtin/collection_processor.py index 613454b86..15234c6a6 100644 --- a/weaver/processes/builtin/collection_processor.py +++ b/weaver/processes/builtin/collection_processor.py @@ -25,19 +25,31 @@ # place weaver specific imports after sys path fixing to ensure they are found from external call # pylint: disable=C0413,wrong-import-order from weaver.execute import ExecuteCollectionFormat # isort:skip # noqa: E402 -from weaver.formats import ContentType, get_extension, find_supported_media_types # isort:skip # noqa: E402 +from weaver.formats import ( # isort:skip # noqa: E402 + ContentType, + find_supported_media_types, + get_cwl_file_format, + get_extension +) from weaver.processes.builtin.utils import ( # isort:skip # noqa: E402 get_package_details, is_geojson_url, validate_reference ) -from weaver.utils import Lazify, load_file, repr_json, request_extra # isort:skip # noqa: E402 +from weaver.utils import Lazify, get_any_id, load_file, repr_json, request_extra # isort:skip # noqa: E402 from weaver.wps_restapi import swagger_definitions as sd # isort:skip # noqa: E402 if TYPE_CHECKING: from pystac import Asset - from weaver.typedefs import JSON, JobValueCollection, ProcessInputOutputItem + from weaver.typedefs import ( + CWL_IO_ValueMap, + JSON, + JobValueCollection, + Path, + ProcessInputOutputItem + ) + from weaver.utils import LoggerHandler PACKAGE_NAME, PACKAGE_BASE, PACKAGE_MODULE = get_package_details(__file__) @@ -54,8 +66,8 @@ OUTPUT_CWL_JSON = "cwl.output.json" -def process(collection_input, input_definition, output_dir): - # type: (JobValueCollection, ProcessInputOutputItem, os.PathLike[str]) -> None +def process_collection(collection_input, input_definition, output_dir, logger=LOGGER): + # type: (JobValueCollection, ProcessInputOutputItem, Path, LoggerHandler) -> CWL_IO_ValueMap """ Processor of a :term:`Collection`. @@ -68,30 +80,34 @@ def process(collection_input, input_definition, output_dir): :param input_definition: Process input definition that indicates the target types, formats and schema to retrieve from the collection. :param output_dir: Directory to write the output (provided by the :term:`CWL` definition). + :param logger: Optional logger handler to employ. :return: Resolved data references. """ - LOGGER.info( - "Process [%s] Got arguments: collection_input=%s output_dir=%s", + input_id = get_any_id(input_definition) + logger.log( + logging.INFO, + "Process [{}] Got arguments: collection_input=[{}] output_dir=[{}], for input=[{}]", PACKAGE_NAME, Lazify(lambda: repr_json(collection_input, indent=2)), output_dir, + input_id, ) col_input = sd.ExecuteCollectionInput().deserialize(collection_input) # type: JobValueCollection col_args = dict(col_input) - col_file = col_ref = col_args.pop("collection") - if not col_ref.endswith("/"): - col_ref += "/" + col_href = col_href_valid = col_args.pop("collection") + if not col_href_valid.endswith("/"): + col_href_valid += "/" col_fmt = col_args.pop("format", None) if col_fmt not in ExecuteCollectionFormat.values(): col_fmt = ExecuteCollectionFormat.GEOJSON # static GeoJSON can be either a file-like reference or a generic server endpoint (directory-like) - if col_fmt == ExecuteCollectionFormat.GEOJSON and not is_geojson_url(col_file): - validate_reference(col_ref, is_file=False) + if col_fmt == ExecuteCollectionFormat.GEOJSON and not is_geojson_url(col_href): + validate_reference(col_href_valid, is_file=False) # otherwise, any other format involves an API endpoint interaction else: - validate_reference(col_ref, is_file=False) + validate_reference(col_href_valid, is_file=False) # find which media-types are applicable for the destination input definition col_media_type = col_args.pop("type", None) @@ -100,7 +116,7 @@ def process(collection_input, input_definition, output_dir): if col_media_type and not isinstance(col_media_type, list): col_media_type = [col_media_type] - api_url, col_id = col_ref.rsplit("/collections/", 1) + api_url, col_id = col_href.rsplit("/collections/", 1) # convert all parameters to their corresponding name of the query utility # all OWSLib utilities use (**kwargs) allowing additional parameters that will be ignored @@ -110,11 +126,12 @@ def process(collection_input, input_definition, output_dir): col_args[arg.replace("-", "_")] = col_args.pop(arg) col_args.setdefault("timeout", 10) + logger.log(logging.INFO, "Attempting resolution of collection [{}] as format [{}]", col_href, col_fmt) resolved_files = [] if col_fmt == ExecuteCollectionFormat.GEOJSON: col_resp = request_extra( "GET", - col_file, + col_href, queries=col_args, headers={"Accept": f"{ContentType.APP_GEOJSON},{ContentType.APP_JSON}"}, timeout=10, @@ -122,13 +139,15 @@ def process(collection_input, input_definition, output_dir): only_server_errors=False, ) if not (col_resp.status_code == 200 and "features" in col_resp.json): - raise ValueError(f"Could not parse [{col_file}] as a GeoJSON FeatureCollection!") + raise ValueError(f"Could not parse [{col_href}] as a GeoJSON FeatureCollection.") for i, feat in enumerate(col_resp.json["features"]): path = os.path.join(output_dir, f"feature-{i}.geojson") with open(path, mode="w", encoding="utf-8") as file: json.dump(feat, file) - resolved_files.append(f"file://{path}") + file_typ = get_cwl_file_format(ContentType.APP_GEOJSON) + file_obj = {"class": "File", "path": f"file://{path}", "format": file_typ} + resolved_files.append(file_obj) elif col_fmt == ExecuteCollectionFormat.STAC: known_params = set(inspect.signature(ItemSearch).parameters) @@ -147,7 +166,9 @@ def process(collection_input, input_definition, output_dir): for item in search.items(): for ctype in col_media_type: for _, asset in item.get_assets(media_type=ctype): # type: (..., Asset) - resolved_files.append(asset.href) + file_typ = get_cwl_file_format(ctype) + file_obj = {"class": "File", "path": asset.href, "format": file_typ} + resolved_files.append(file_obj) elif col_fmt == ExecuteCollectionFormat.OGC_FEATURES: if str(col_args.get("filter_lang")) == "cql2-json": @@ -163,44 +184,64 @@ def process(collection_input, input_definition, output_dir): if "assets" in feat and col_media_type != [ContentType.APP_GEOJSON]: for name, asset in feat["assets"].items(): # type: (str, JSON) if asset["type"] in col_media_type: - resolved_files.append(asset["href"]) + file_typ = get_cwl_file_format(asset["type"]) + file_obj = {"class": "File", "path": asset["href"], "format": file_typ} + resolved_files.append(file_obj) else: path = os.path.join(output_dir, f"feature-{i}.geojson") with open(path, mode="w", encoding="utf-8") as file: json.dump(feat, file) - resolved_files.append(f"file://{path}") + file_typ = get_cwl_file_format(ContentType.APP_GEOJSON) + file_obj = {"class": "File", "path": f"file://{path}", "format": file_typ} + resolved_files.append(file_obj) elif col_fmt == ExecuteCollectionFormat.OGC_COVERAGE: cov = Coverages( url=api_url, # FIXME: add 'auth' or 'headers'? ) - ctype = col_media_type or [ContentType.IMAGE_GEOTIFF] - ext = get_extension(ctype[0], dot=False) + ctype = (col_media_type or [ContentType.IMAGE_COG])[0] + ext = get_extension(ctype, dot=False) path = os.path.join(output_dir, f"map.{ext}") with open(path, mode="wb") as file: data = cast(io.BytesIO, cov.coverage(col_id)).getbuffer() file.write(data) # type: ignore - resolved_files.append(path) + file_typ = get_cwl_file_format(ctype) + file_obj = {"class": "File", "path": f"file://{path}", "format": file_typ} + resolved_files.append(file_obj) elif col_fmt in ExecuteCollectionFormat.OGC_MAP: maps = Maps( url=api_url, # FIXME: add 'auth' or 'headers'? ) - ctype = col_media_type or [ContentType.IMAGE_GEOTIFF] + ctype = (col_media_type or [ContentType.IMAGE_COG])[0] ext = get_extension(ctype[0], dot=False) path = os.path.join(output_dir, f"map.{ext}") with open(path, mode="wb") as file: data = cast(io.BytesIO, maps.map(col_id)).getbuffer() file.write(data) # type: ignore - resolved_files.append(path) + file_typ = get_cwl_file_format(ctype) + file_obj = {"class": "File", "path": f"file://{path}", "format": file_typ} + resolved_files.append(file_obj) - outputs = { - "outputs": [{"class": "File", "location": path} for path in resolved_files], - } + else: + raise ValueError(f"Collection [{col_href}] could not be resolved. Unknown format [{col_fmt}].") + + if not resolved_files: + raise ValueError(f"Could not extract any data or reference from input collection [{col_href}].") + + outputs = {"outputs": resolved_files} # 'outputs' must match ID used in CWL definition + logger.log(logging.INFO, "Resolved collection [{}] returned {} references.", col_href, len(resolved_files)) + logger.log( + logging.DEBUG, + "Resolved collection [{}] files:\n", + col_href, + Lazify(lambda: repr_json(resolved_files, indent=2)), + ) with open(os.path.join(output_dir, OUTPUT_CWL_JSON), mode="w", encoding="utf-8") as file: - return json.dump(outputs, file) + json.dump(outputs, file) + return outputs def main(*args): @@ -230,7 +271,7 @@ def main(*args): col_in = load_file(ns.c) LOGGER.info("Process [%s] Loading process input definition '%s'.", PACKAGE_NAME, ns.p) proc_in = load_file(ns.p) - sys.exit(process(col_in, proc_in, ns.o)) + sys.exit(process_collection(col_in, proc_in, ns.o) is not None) if __name__ == "__main__": diff --git a/weaver/processes/execution.py b/weaver/processes/execution.py index d89aae657..e32adb232 100644 --- a/weaver/processes/execution.py +++ b/weaver/processes/execution.py @@ -16,10 +16,11 @@ from weaver.database import get_db from weaver.datatype import Process, Service from weaver.execute import ExecuteControlOption, ExecuteMode -from weaver.formats import AcceptLanguage, ContentType, clean_media_type_format, repr_json +from weaver.formats import AcceptLanguage, ContentType, clean_media_type_format, map_cwl_media_type, repr_json from weaver.notify import map_job_subscribers, notify_job_subscribers from weaver.owsexceptions import OWSInvalidParameterValue, OWSNoApplicableCode from weaver.processes import wps_package +from weaver.processes.builtin.collection_processor import process_collection from weaver.processes.constants import WPS_BOUNDINGBOX_DATA, WPS_COMPLEX_DATA, JobInputsOutputsSchema from weaver.processes.convert import ( convert_input_values_schema, @@ -284,6 +285,7 @@ def execute_process(task, job_id, wps_url, headers=None): job.status = Status.DISMISSED task_success = map_status(job.status) not in JOB_STATUS_CATEGORIES[StatusCategory.FAILED] collect_statistics(task_process, settings, job, rss_start) + job.cleanup() if task_success: job.progress = JobProgress.EXECUTE_MONITOR_END job.status_message = f"Job {job.status}." @@ -485,7 +487,11 @@ def parse_wps_input_literal(input_value): def parse_wps_inputs(wps_process, job): # type: (ProcessOWS, Job) -> List[Tuple[str, OWS_Input_Type]] """ - Parses expected WPS process inputs against submitted job input values considering supported process definitions. + Parses expected :term:`WPS` process inputs against submitted job input values considering supported definitions. + + According to the structure of the job inputs, and notably their key arguments, perform the relevant parsing and + data retrieval to prepare inputs in a native format that can be understood and employed by a :term:`WPS` worker + (i.e.: :class:`weaver.wps.service.WorkerService` and its underlying :mod:`pywps` implementation). """ complex_inputs = {} # type: Dict[str, ComplexInput] bbox_inputs = {} # type: Dict[str, BoundingBoxInput] @@ -528,6 +534,15 @@ def parse_wps_inputs(wps_process, job): if input_value is None: input_data = None else: + # pre-check collection for resolution of the referenced data + if isinstance(input_value, dict) and "collection" in input_value: + col_path = os.path.join(job.tmpdir, "inputs", input_id) + col_out = process_collection(input_value, input_info, col_path, logger=job) + input_value = [ + {"href": col_file["path"], "type": map_cwl_media_type(col_file["format"])} + for col_file in col_out["outputs"] + ] + # resolve according to relevant data type parsing # value could be an embedded or remote definition if input_id in complex_inputs: @@ -536,6 +551,7 @@ def parse_wps_inputs(wps_process, job): input_data = parse_wps_input_bbox(input_value, input_info) else: input_data = parse_wps_input_literal(input_value) + # re-validate the resolved data as applicable if input_data is None: job.save_log( diff --git a/weaver/processes/utils.py b/weaver/processes/utils.py index b97c1cdd9..9450f6a3d 100644 --- a/weaver/processes/utils.py +++ b/weaver/processes/utils.py @@ -74,10 +74,11 @@ LOGGER = logging.getLogger(__name__) if TYPE_CHECKING: - from typing import Any, List, Optional, Protocol, Tuple, Union + from typing import Any, List, Optional, Tuple, Union from docker.client import DockerClient + from weaver.utils import LoggerHandler from weaver.typedefs import ( AnyHeadersContainer, AnyRegistryContainer, @@ -96,11 +97,6 @@ TypedDict ) - class LoggerHandler(Protocol): - def log(self, level, message, *args, **kwargs): - # type: (int, str, *str, **Any) -> None - ... - UpdateFieldListMethod = Literal["append", "override"] UpdateFieldListSpec = TypedDict("UpdateFieldListSpec", { "source": str, diff --git a/weaver/typedefs.py b/weaver/typedefs.py index bb0ecad4a..0af6a3c4c 100644 --- a/weaver/typedefs.py +++ b/weaver/typedefs.py @@ -71,7 +71,7 @@ from weaver.status import AnyStatusType, StatusType from weaver.visibility import AnyVisibility - Path = Union[os.PathLike, str, bytes] + Path = Union[os.PathLike[str], str, bytes] Default = TypeVar("Default") # used for return value that is employed from a provided default value Params = ParamSpec("Params") # use with 'Callable[Params, Return]', 'Params.args' and 'Params.kwargs' diff --git a/weaver/utils.py b/weaver/utils.py index f2ccc13e0..fea71160d 100644 --- a/weaver/utils.py +++ b/weaver/utils.py @@ -20,7 +20,7 @@ from copy import deepcopy from datetime import datetime from pkgutil import get_loader -from typing import TYPE_CHECKING, overload +from typing import TYPE_CHECKING, Protocol, overload from urllib.parse import ParseResult, parse_qsl, unquote, urlparse, urlunsplit import boto3 @@ -199,6 +199,17 @@ class ExtendedClass(OriginalClass, ExtenderMixin): LOGGER = logging.getLogger(__name__) + +class LoggerHandler(Protocol): + """ + Minimalistic logger interface (typically :class:`logging.Logger`) intended to be used only with ``log`` method. + """ + + def log(self, level, message, *args, **kwargs): + # type: (int, str, *Any, **Any) -> None + ... + + SUPPORTED_FILE_SCHEMES = frozenset([ "file", "http",