From a9ba7e273c84b4b7d2335d10f0b2efc791a910ca Mon Sep 17 00:00:00 2001
From: Francis Charette Migneault
Date: Fri, 23 Aug 2024 19:47:14 -0400
Subject: [PATCH] fixes to handle collection to files input + fix tmpdir forward after job update
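
A "collection" execution input is now resolved into one or many local file
references before WPS input parsing, and the temporary directory holding
those files survives the "update_job" round-trip so it can still be cleaned
up once the job completes. For illustration, the payload shape exercised by
the new tests looks roughly as follows (the hostname is a test fixture
served by "mocked_file_server", not a real endpoint):

    col_exec_body = {
        "mode": ExecuteMode.ASYNC,
        "response": ExecuteResponse.DOCUMENT,
        "inputs": {
            "features": {
                "collection": "https://mocked-file-server.com/collections/test",
                "format": ExecuteCollectionFormat.OGC_FEATURES,
                "type": ContentType.APP_GEOJSON,
                "filter-lang": "cql2-text",
                "filter": "properties.name = test"
            }
        }
    }
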
---
 .../EchoFeatures/echo_features.cwl            |  1 -
 tests/functional/test_wps_package.py          | 73 +++++++++++++++----
 weaver/datatype.py                            | 13 +++-
 weaver/formats.py                             | 18 +++++
 .../processes/builtin/collection_processor.py | 58 +++++++++------
 weaver/processes/execution.py                 | 37 +++++++---
 weaver/store/mongodb.py                       |  4 +-
 7 files changed, 154 insertions(+), 50 deletions(-)

diff --git a/tests/functional/application-packages/EchoFeatures/echo_features.cwl b/tests/functional/application-packages/EchoFeatures/echo_features.cwl
index 9edd9dbe6..67aa857fe 100644
--- a/tests/functional/application-packages/EchoFeatures/echo_features.cwl
+++ b/tests/functional/application-packages/EchoFeatures/echo_features.cwl
@@ -25,7 +25,6 @@ inputs:
             return {
               "type": "FeatureCollection",
               "features": inputs.features.every(item => item.contents)
-            )
             };
           }
           return inputs.features.contents;
diff --git a/tests/functional/test_wps_package.py b/tests/functional/test_wps_package.py
index afa684ae2..9b3f83033 100644
--- a/tests/functional/test_wps_package.py
+++ b/tests/functional/test_wps_package.py
@@ -2266,33 +2266,80 @@ def test_execute_job_with_bbox(self):
             "Expected the BBOX CRS URI to be interpreted and validated by known WPS definitions."
         )
 
-    def test_execute_job_with_collection_input(self):
+    def test_execute_job_with_collection_input_geojson_feature_collection(self):
         name = "EchoFeatures"
         body = self.retrieve_payload(name, "deploy", local=True)
         proc = self.fully_qualified_test_process_name(self._testMethodName)
         self.deploy_process(body, describe_schema=ProcessSchema.OGC, process_id=proc)
 
         with contextlib.ExitStack() as stack:
+            tmp_host = "https://mocked-file-server.com"  # must match collection prefix hostnames
             tmp_dir = stack.enter_context(tempfile.TemporaryDirectory())  # pylint: disable=R1732
-            tmp_feature_collection_geojson = stack.enter_context(
-                tempfile.NamedTemporaryFile(suffix=".geojson", mode="w", dir=tmp_dir)  # pylint: disable=R1732
-            )
+            stack.enter_context(mocked_file_server(tmp_dir, tmp_host, settings=self.settings, mock_browse_index=True))
+
+            col_file = os.path.join(tmp_dir, "test.geojson")
             exec_body_val = self.retrieve_payload(name, "execute", local=True)
-            json.dump(
-                exec_body_val["inputs"]["features"]["value"],
-                tmp_feature_collection_geojson,
-            )
-            tmp_feature_collection_geojson.flush()
-            tmp_feature_collection_geojson.seek(0)
+            with open(col_file, mode="w", encoding="utf-8") as tmp_feature_collection_geojson:
+                json.dump(
+                    exec_body_val["inputs"]["features"]["value"],
+                    tmp_feature_collection_geojson,
+                )
 
-            exec_body_col = {
+            col_exec_body = {
                 "mode": ExecuteMode.ASYNC,
                 "response": ExecuteResponse.DOCUMENT,
                 "inputs": {
                     "features": {
-                        "collection": "https://mocked-file-server.com/collections/test",
+                        # accessed directly as a static GeoJSON FeatureCollection
+                        "collection": "https://mocked-file-server.com/test.geojson",
                         "format": ExecuteCollectionFormat.GEOJSON,
                         "type": ContentType.APP_GEOJSON,
+                    },
+                }
+            }
+
+            for mock_exec in mocked_execute_celery():
+                stack.enter_context(mock_exec)
+            proc_url = f"/processes/{proc}/execution"
+            resp = mocked_sub_requests(self.app, "post_json", proc_url, timeout=5,
+                                       data=col_exec_body, headers=self.json_headers, only_local=True)
+            assert resp.status_code in [200, 201], f"Failed with: [{resp.status_code}]\nReason:\n{resp.json}"
+
+            status_url = resp.json["location"]
+            results = self.monitor_job(status_url)
+            assert "outputs" in results
+
+            raise NotImplementedError  # FIXME: implement! (see above bbox case for inspiration)
+
+    def test_execute_job_with_collection_input_ogc_features(self):
+        name = "EchoFeatures"
+        body = self.retrieve_payload(name, "deploy", local=True)
+        proc = self.fully_qualified_test_process_name(self._testMethodName)
+        self.deploy_process(body, describe_schema=ProcessSchema.OGC, process_id=proc)
+
+        with contextlib.ExitStack() as stack:
+            tmp_host = "https://mocked-file-server.com"  # must match collection prefix hostnames
+            tmp_dir = stack.enter_context(tempfile.TemporaryDirectory())  # pylint: disable=R1732
+            stack.enter_context(mocked_file_server(tmp_dir, tmp_host, settings=self.settings, mock_browse_index=True))
+
+            col_dir = os.path.join(tmp_dir, "collections/test")
+            col_file = os.path.join(col_dir, "items")
+            os.makedirs(col_dir)
+            exec_body_val = self.retrieve_payload(name, "execute", local=True)
+            with open(col_file, mode="w", encoding="utf-8") as tmp_feature_collection_geojson:
+                json.dump(
+                    exec_body_val["inputs"]["features"]["value"],
+                    tmp_feature_collection_geojson,
+                )
+
+            col_exec_body = {
+                "mode": ExecuteMode.ASYNC,
+                "response": ExecuteResponse.DOCUMENT,
+                "inputs": {
+                    "features": {
+                        "collection": "https://mocked-file-server.com/collections/test",
+                        "format": ExecuteCollectionFormat.OGC_FEATURES,
+                        "type": ContentType.APP_GEOJSON,
                         "filter-lang": "cql2-text",
                         "filter": "properties.name = test"
                     }
@@ -2303,7 +2350,7 @@ def test_execute_job_with_collection_input(self):
                 stack.enter_context(mock_exec)
             proc_url = f"/processes/{proc}/execution"
             resp = mocked_sub_requests(self.app, "post_json", proc_url, timeout=5,
-                                       data=exec_body_col, headers=self.json_headers, only_local=True)
+                                       data=col_exec_body, headers=self.json_headers, only_local=True)
             assert resp.status_code in [200, 201], f"Failed with: [{resp.status_code}]\nReason:\n{resp.json}"
 
             status_url = resp.json["location"]
diff --git a/weaver/datatype.py b/weaver/datatype.py
index 306f88a86..adbd712fc 100644
--- a/weaver/datatype.py
+++ b/weaver/datatype.py
@@ -637,9 +637,18 @@ def __init__(self, *args, **kwargs):
             raise TypeError(f"Type 'str' or 'UUID' is required for '{self.__name__}.id'")
         self["__tmpdir"] = None
 
+    def update_from(self, job):
+        # type: (Job) -> None
+        """
+        Forwards any internal or control properties from the specified :class:`Job` to this one.
+        """
+        self["__tmpdir"] = job.get("__tmpdir")
+
     def cleanup(self):
-        if self["__tmpdir"] and os.path.isdir(self["__tmpdir"]):
-            shutil.rmtree(self["__tmpdir"], ignore_errors=True)
+        # type: () -> None
+        _tmpdir = self.get("__tmpdir")
+        if isinstance(_tmpdir, str) and os.path.isdir(_tmpdir):
+            shutil.rmtree(_tmpdir, ignore_errors=True)
 
     @property
     def tmpdir(self):
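
The "__tmpdir" entry is in-memory state that is never persisted to the
database, so a re-fetched "Job" starts without it. A minimal, runnable
sketch of the forwarding pattern, using a hypothetical "JobLike" stand-in
rather than the real "Job" class:

    import os
    import shutil
    import tempfile


    class JobLike(dict):
        """Hypothetical stand-in for 'weaver.datatype.Job' illustrating '__tmpdir' forwarding."""

        @property
        def tmpdir(self):
            # lazily create and remember the temporary directory, as the real property does
            if not self.get("__tmpdir"):
                self["__tmpdir"] = tempfile.mkdtemp()
            return self["__tmpdir"]

        def update_from(self, job):
            # forward internal references that are not part of the persisted fields
            self["__tmpdir"] = job.get("__tmpdir")

        def cleanup(self):
            _tmpdir = self.get("__tmpdir")
            if isinstance(_tmpdir, str) and os.path.isdir(_tmpdir):
                shutil.rmtree(_tmpdir, ignore_errors=True)


    job_a = JobLike()
    path = job_a.tmpdir       # created on first access, known only in memory
    job_b = JobLike()         # e.g., the fresh instance returned by a store re-fetch
    job_b.update_from(job_a)  # without this, 'job_b' could never clean the directory
    job_b.cleanup()
    assert not os.path.isdir(path)
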
diff --git a/weaver/formats.py b/weaver/formats.py
index 3b682d7e7..a216ca4b6 100644
--- a/weaver/formats.py
+++ b/weaver/formats.py
@@ -747,6 +747,24 @@ def add_content_type_charset(content_type, charset):
     return content_type
 
 
+@overload
+def get_cwl_file_format(media_type):
+    # type: (str) -> Tuple[Optional[JSON], Optional[str]]
+    ...
+
+
+@overload
+def get_cwl_file_format(media_type, make_reference=False):
+    # type: (str, Literal[False]) -> Tuple[Optional[JSON], Optional[str]]
+    ...
+
+
+@overload
+def get_cwl_file_format(media_type, make_reference=False):
+    # type: (str, Literal[True]) -> Optional[str]
+    ...
+
+
 @functools.cache
 def get_cwl_file_format(media_type, make_reference=False, must_exist=True, allow_synonym=True):  # pylint: disable=R1260
    # type: (str, bool, bool, bool) -> Union[Tuple[Optional[JSON], Optional[str]], Optional[str]]
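
For reference, the two dispatch modes that these overloads encode, as used
by the call sites updated below (the namespace and reference values shown
in the comments are indicative only, assuming the usual IANA media-type
namespace):

    from weaver.formats import ContentType, get_cwl_file_format

    fmt_ns, fmt_ref = get_cwl_file_format(ContentType.APP_GEOJSON)
    # fmt_ns:  CWL '$namespaces' fragment, e.g. {"iana": "https://www.iana.org/assignments/media-types/"}
    # fmt_ref: namespaced 'format' for a CWL 'File', e.g. "iana:application/geo+json"

    fmt_ref = get_cwl_file_format(ContentType.APP_GEOJSON, make_reference=True)
    # only the reference string, when the namespace mapping is not needed
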
f"file://{path}", "format": file_fmt} resolved_files.append(file_obj) elif col_fmt == ExecuteCollectionFormat.STAC: @@ -166,8 +174,8 @@ def process_collection(collection_input, input_definition, output_dir, logger=LO for item in search.items(): for ctype in col_media_type: for _, asset in item.get_assets(media_type=ctype): # type: (..., Asset) - file_typ = get_cwl_file_format(ctype) - file_obj = {"class": "File", "path": asset.href, "format": file_typ} + _, file_fmt = get_cwl_file_format(ctype) + file_obj = {"class": "File", "path": asset.href, "format": file_fmt} resolved_files.append(file_obj) elif col_fmt == ExecuteCollectionFormat.OGC_FEATURES: @@ -184,15 +192,15 @@ def process_collection(collection_input, input_definition, output_dir, logger=LO if "assets" in feat and col_media_type != [ContentType.APP_GEOJSON]: for name, asset in feat["assets"].items(): # type: (str, JSON) if asset["type"] in col_media_type: - file_typ = get_cwl_file_format(asset["type"]) - file_obj = {"class": "File", "path": asset["href"], "format": file_typ} + _, file_fmt = get_cwl_file_format(asset["type"]) + file_obj = {"class": "File", "path": asset["href"], "format": file_fmt} resolved_files.append(file_obj) else: path = os.path.join(output_dir, f"feature-{i}.geojson") with open(path, mode="w", encoding="utf-8") as file: json.dump(feat, file) - file_typ = get_cwl_file_format(ContentType.APP_GEOJSON) - file_obj = {"class": "File", "path": f"file://{path}", "format": file_typ} + _, file_fmt = get_cwl_file_format(ContentType.APP_GEOJSON) + file_obj = {"class": "File", "path": f"file://{path}", "format": file_fmt} resolved_files.append(file_obj) elif col_fmt == ExecuteCollectionFormat.OGC_COVERAGE: @@ -206,8 +214,8 @@ def process_collection(collection_input, input_definition, output_dir, logger=LO with open(path, mode="wb") as file: data = cast(io.BytesIO, cov.coverage(col_id)).getbuffer() file.write(data) # type: ignore - file_typ = get_cwl_file_format(ctype) - file_obj = {"class": "File", "path": f"file://{path}", "format": file_typ} + _, file_fmt = get_cwl_file_format(ctype) + file_obj = {"class": "File", "path": f"file://{path}", "format": file_fmt} resolved_files.append(file_obj) elif col_fmt in ExecuteCollectionFormat.OGC_MAP: @@ -221,8 +229,8 @@ def process_collection(collection_input, input_definition, output_dir, logger=LO with open(path, mode="wb") as file: data = cast(io.BytesIO, maps.map(col_id)).getbuffer() file.write(data) # type: ignore - file_typ = get_cwl_file_format(ctype) - file_obj = {"class": "File", "path": f"file://{path}", "format": file_typ} + _, file_fmt = get_cwl_file_format(ctype) + file_obj = {"class": "File", "path": f"file://{path}", "format": file_fmt} resolved_files.append(file_obj) else: @@ -231,14 +239,20 @@ def process_collection(collection_input, input_definition, output_dir, logger=LO if not resolved_files: raise ValueError(f"Could not extract any data or reference from input collection [{col_href}].") - outputs = {"outputs": resolved_files} # 'outputs' must match ID used in CWL definition - logger.log(logging.INFO, "Resolved collection [{}] returned {} references.", col_href, len(resolved_files)) + logger.log(logging.INFO, "Resolved collection [{}] returned {} reference(s).", col_href, len(resolved_files)) logger.log( logging.DEBUG, - "Resolved collection [{}] files:\n", + "Resolved collection [{}] files:\n{}", col_href, Lazify(lambda: repr_json(resolved_files, indent=2)), ) + return resolved_files + + +def process_cwl(collection_input, input_definition, output_dir): + # 
diff --git a/weaver/processes/execution.py b/weaver/processes/execution.py
index e32adb232..7f8b6f20f 100644
--- a/weaver/processes/execution.py
+++ b/weaver/processes/execution.py
@@ -276,7 +276,9 @@ def execute_process(task, job_id, wps_url, headers=None):
             job.save_log(errors=errors, logger=task_logger)
             job = store.update_job(job)
     finally:
-        # note:
+        # WARNING: important to clean up before re-fetching, otherwise we lose internal references needing cleanup
+        job.cleanup()
+        # NOTE:
         # don't update the progress and status here except for 'success' to preserve last error that was set
         # it is more relevant to return the latest step that worked properly to understand where it failed
         job = store.fetch_by_id(job.id)
@@ -285,7 +287,6 @@ def execute_process(task, job_id, wps_url, headers=None):
             job.status = Status.DISMISSED
         task_success = map_status(job.status) not in JOB_STATUS_CATEGORIES[StatusCategory.FAILED]
         collect_statistics(task_process, settings, job, rss_start)
-        job.cleanup()
         if task_success:
             job.progress = JobProgress.EXECUTE_MONITOR_END
             job.status_message = f"Job {job.status}."
@@ -422,6 +423,8 @@ def parse_wps_input_complex(input_value, input_info):
     schema_vars = ["reference", "$schema"]
     input_field = get_any_value(input_info, key=True)
     if isinstance(input_value, dict):
+        if input_field is None:
+            input_field = get_any_value(input_value, key=True)
         ctype, c_enc = parse_wps_input_format(input_value, "type", search_variations=False)
         if not ctype:
             ctype, c_enc = parse_wps_input_format(input_value)
@@ -527,22 +530,34 @@ def parse_wps_inputs(wps_process, job):
                 input_values = [input_val]
                 input_details = [job_input]  # metadata directly in definition, not nested per array value
 
+            # Pre-check collection inputs to resolve the referenced data.
+            # Because each collection input can result in either a '1->1' or '1->N' file reference mapping,
+            # resolution must be performed before iterating over the input values/definitions to parse them.
+            # Whether the sink input receiving this data can map to 1 or N values is validated later by the execution.
+            resolved_inputs = []
             for input_value, input_info in zip(input_values, input_details):
+                if isinstance(input_info, dict):
+                    input_info["id"] = input_id
+                if isinstance(input_value, dict) and "collection" in input_value:
+                    col_path = os.path.join(job.tmpdir, "inputs", input_id)
+                    col_files = process_collection(input_value, input_info, col_path, logger=job)
+                    resolved_inputs.extend([
+                        (
+                            {"href": col_file["path"], "type": map_cwl_media_type(col_file["format"])},
+                            input_info
+                        )
+                        for col_file in col_files
+                    ])
+                else:
+                    resolved_inputs.append((input_value, input_info))
+
+            for input_value, input_info in resolved_inputs:
                 # if already resolved, skip parsing
                 # it is important to omit explicitly provided 'null', otherwise the WPS object could be misleading
                 # for example, a 'ComplexData' with 'null' data will be auto-generated as text/plan with "null" string
                 if input_value is None:
                     input_data = None
                 else:
-                    # pre-check collection for resolution of the referenced data
-                    if isinstance(input_value, dict) and "collection" in input_value:
-                        col_path = os.path.join(job.tmpdir, "inputs", input_id)
-                        col_out = process_collection(input_value, input_info, col_path, logger=job)
-                        input_value = [
-                            {"href": col_file["path"], "type": map_cwl_media_type(col_file["format"])}
-                            for col_file in col_out["outputs"]
-                        ]
-
                     # resolve according to relevant data type parsing
                     # value could be an embedded or remote definition
                     if input_id in complex_inputs:
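
The '1->N' expansion means one submitted value can yield several parsed
inputs sharing the same definition. Roughly (the collection URL, file names
and 'input_info' are placeholders):

    input_value = {"collection": "https://example.com/collections/test", "format": ExecuteCollectionFormat.OGC_FEATURES}
    # after 'process_collection' resolves, e.g., two features:
    resolved_inputs = [
        ({"href": "file:///tmp/<job>/inputs/features/feature-0.geojson", "type": ContentType.APP_GEOJSON}, input_info),
        ({"href": "file:///tmp/<job>/inputs/features/feature-1.geojson", "type": ContentType.APP_GEOJSON}, input_info),
    ]
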
diff --git a/weaver/store/mongodb.py b/weaver/store/mongodb.py
index 296213a22..3b7c80cc4 100644
--- a/weaver/store/mongodb.py
+++ b/weaver/store/mongodb.py
@@ -884,7 +884,9 @@ def update_job(self, job):
             job.updated = now()
             result = self.collection.update_one({"id": job.id}, {"$set": job.params()})
             if result.acknowledged and result.matched_count == 1:
-                return self.fetch_by_id(job.id)
+                updated_job = self.fetch_by_id(job.id)
+                updated_job.update_from(job)
+                return updated_job
         except Exception as ex:
             raise JobUpdateError(f"Error occurred during job update: [{ex!r}]")
         raise JobUpdateError(f"Failed to update specified job: '{job!s}'")
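
Note the resulting asymmetry: "update_job" forwards the in-memory
"__tmpdir" onto the instance it returns, while a plain "fetch_by_id" still
yields a fresh "Job" without it. This is why "execute_process" above calls
"job.cleanup()" before its final re-fetch; schematically:

    job = store.update_job(job)      # '__tmpdir' preserved via 'update_from'
    job.cleanup()                    # reference still known: temporary inputs are removed
    job = store.fetch_by_id(job.id)  # fresh instance: '__tmpdir' is None, 'cleanup()' would be a no-op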