
Commit

fixes to handle collection to files input + fix tmpdir forward after job update
fmigneault committed Aug 23, 2024
1 parent dd0c8c7 commit a9ba7e2
Showing 7 changed files with 154 additions and 50 deletions.
@@ -25,7 +25,6 @@ inputs:
return {
"type": "FeatureCollection",
"features": inputs.features.every(item => item.contents)
)
};
}
return inputs.features.contents;
73 changes: 60 additions & 13 deletions tests/functional/test_wps_package.py
@@ -2266,33 +2266,80 @@ def test_execute_job_with_bbox(self):
"Expected the BBOX CRS URI to be interpreted and validated by known WPS definitions."
)

def test_execute_job_with_collection_input(self):
def test_execute_job_with_collection_input_geojson_feature_collection(self):
name = "EchoFeatures"
body = self.retrieve_payload(name, "deploy", local=True)
proc = self.fully_qualified_test_process_name(self._testMethodName)
self.deploy_process(body, describe_schema=ProcessSchema.OGC, process_id=proc)

with contextlib.ExitStack() as stack:
tmp_host = "https://mocked-file-server.com" # must match collection prefix hostnames
tmp_dir = stack.enter_context(tempfile.TemporaryDirectory()) # pylint: disable=R1732
tmp_feature_collection_geojson = stack.enter_context(
tempfile.NamedTemporaryFile(suffix=".geojson", mode="w", dir=tmp_dir) # pylint: disable=R1732
)
stack.enter_context(mocked_file_server(tmp_dir, tmp_host, settings=self.settings, mock_browse_index=True))

col_file = os.path.join(tmp_dir, "test.geojson")
exec_body_val = self.retrieve_payload(name, "execute", local=True)
json.dump(
exec_body_val["inputs"]["features"]["value"],
tmp_feature_collection_geojson,
)
tmp_feature_collection_geojson.flush()
tmp_feature_collection_geojson.seek(0)
with open(col_file, mode="w", encoding="utf-8") as tmp_feature_collection_geojson:
json.dump(
exec_body_val["inputs"]["features"]["value"],
tmp_feature_collection_geojson,
)

exec_body_col = {
col_exec_body = {
"mode": ExecuteMode.ASYNC,
"response": ExecuteResponse.DOCUMENT,
"inputs": {
"features": {
"collection": "https://mocked-file-server.com/collections/test",
# accessed directly as a static GeoJSON FeatureCollection
"collection": "https://mocked-file-server.com/test.geojson",
"format": ExecuteCollectionFormat.GEOJSON,
"type": ContentType.APP_GEOJSON,
},
}
}

for mock_exec in mocked_execute_celery():
stack.enter_context(mock_exec)
proc_url = f"/processes/{proc}/execution"
resp = mocked_sub_requests(self.app, "post_json", proc_url, timeout=5,
data=col_exec_body, headers=self.json_headers, only_local=True)
assert resp.status_code in [200, 201], f"Failed with: [{resp.status_code}]\nReason:\n{resp.json}"

status_url = resp.json["location"]
results = self.monitor_job(status_url)
assert "outputs" in results

raise NotImplementedError # FIXME: implement! (see above bbox case for inspiration)

def test_execute_job_with_collection_input_ogc_features(self):
name = "EchoFeatures"
body = self.retrieve_payload(name, "deploy", local=True)
proc = self.fully_qualified_test_process_name(self._testMethodName)
self.deploy_process(body, describe_schema=ProcessSchema.OGC, process_id=proc)

with contextlib.ExitStack() as stack:
tmp_host = "https://mocked-file-server.com" # must match collection prefix hostnames
tmp_dir = stack.enter_context(tempfile.TemporaryDirectory()) # pylint: disable=R1732
stack.enter_context(mocked_file_server(tmp_dir, tmp_host, settings=self.settings, mock_browse_index=True))

col_dir = os.path.join(tmp_dir, "collections/test")
col_file = os.path.join(col_dir, "items")
os.makedirs(col_dir)
exec_body_val = self.retrieve_payload(name, "execute", local=True)
with open(col_file, mode="w", encoding="utf-8") as tmp_feature_collection_geojson:
json.dump(
exec_body_val["inputs"]["features"]["value"],
tmp_feature_collection_geojson,
)

col_exec_body = {
"mode": ExecuteMode.ASYNC,
"response": ExecuteResponse.DOCUMENT,
"inputs": {
"features": {
"collection": "https://mocked-file-server.com/collections/test",
"format": ExecuteCollectionFormat.OGC_FEATURES,
"type": ContentType.APP_GEOJSON,
"filter-lang": "cql2-text",
"filter": "properties.name = test"
}
@@ -2303,7 +2350,7 @@ def test_execute_job_with_collection_input(self):
stack.enter_context(mock_exec)
proc_url = f"/processes/{proc}/execution"
resp = mocked_sub_requests(self.app, "post_json", proc_url, timeout=5,
data=exec_body_col, headers=self.json_headers, only_local=True)
data=col_exec_body, headers=self.json_headers, only_local=True)
assert resp.status_code in [200, 201], f"Failed with: [{resp.status_code}]\nReason:\n{resp.json}"

status_url = resp.json["location"]
13 changes: 11 additions & 2 deletions weaver/datatype.py
@@ -637,9 +637,18 @@ def __init__(self, *args, **kwargs):
raise TypeError(f"Type 'str' or 'UUID' is required for '{self.__name__}.id'")
self["__tmpdir"] = None

def update_from(self, job):
# type: (Job) -> None
"""
Forwards any internal or control properties from the specified :class:`Job` to this one.
"""
self["__tmpdir"] = job.get("__tmpdir")

def cleanup(self):
if self["__tmpdir"] and os.path.isdir(self["__tmpdir"]):
shutil.rmtree(self["__tmpdir"], ignore_errors=True)
# type: () -> None
_tmpdir = self.get("__tmpdir")
if isinstance(_tmpdir, str) and os.path.isdir(_tmpdir):
shutil.rmtree(_tmpdir, ignore_errors=True)

@property
def tmpdir(self):
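As an aside, here is a minimal self-contained sketch of the intent behind update_from() and cleanup(); the Job class below is a simplified stand-in for illustration, not weaver.datatype.Job, and the "re-fetched copy" scenario is assumed from the execution.py change further down.

import os
import shutil
import tempfile


class Job(dict):
    """Simplified stand-in: tracks a private "__tmpdir" key like the real class."""

    def update_from(self, job):
        # forward internal/control properties from another instance
        self["__tmpdir"] = job.get("__tmpdir")

    def cleanup(self):
        _tmpdir = self.get("__tmpdir")
        if isinstance(_tmpdir, str) and os.path.isdir(_tmpdir):
            shutil.rmtree(_tmpdir, ignore_errors=True)


original = Job(__tmpdir=tempfile.mkdtemp())
refetched = Job()                 # e.g., a copy reloaded from the job store, missing internal keys
refetched.update_from(original)   # without this forward, the temporary directory would be orphaned
refetched.cleanup()
assert not os.path.isdir(original["__tmpdir"])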
18 changes: 18 additions & 0 deletions weaver/formats.py
@@ -747,6 +747,24 @@ def add_content_type_charset(content_type, charset):
return content_type


@overload
def get_cwl_file_format(media_type):
    # type: (str) -> Tuple[Optional[JSON], Optional[str]]
    ...


@overload
def get_cwl_file_format(media_type, make_reference=False):
    # type: (str, Literal[False]) -> Tuple[Optional[JSON], Optional[str]]
    ...


@overload
def get_cwl_file_format(media_type, make_reference=True):
    # type: (str, Literal[True]) -> Optional[str]
    ...


@functools.cache
def get_cwl_file_format(media_type, make_reference=False, must_exist=True, allow_synonym=True): # pylint: disable=R1260
# type: (str, bool, bool, bool) -> Union[Tuple[Optional[JSON], Optional[str]], Optional[str]]
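For readers unfamiliar with this typing idiom, a self-contained sketch of the same @overload + Literal pattern follows, using a toy get_format() function so it stays runnable; the namespace dictionary and "iana:" reference below are illustrative assumptions, not Weaver's actual return values.

from typing import Literal, Optional, Tuple, Union, overload


@overload
def get_format(media_type):
    # type: (str) -> Tuple[Optional[dict], Optional[str]]
    ...


@overload
def get_format(media_type, make_reference=False):
    # type: (str, Literal[False]) -> Tuple[Optional[dict], Optional[str]]
    ...


@overload
def get_format(media_type, make_reference=True):
    # type: (str, Literal[True]) -> Optional[str]
    ...


def get_format(media_type, make_reference=False):
    # type: (str, bool) -> Union[Tuple[Optional[dict], Optional[str]], Optional[str]]
    reference = f"iana:{media_type}"  # placeholder namespaced reference
    if make_reference:
        return reference  # reference-only form selected by 'make_reference=True'
    return {"iana": "https://www.iana.org/assignments/media-types/"}, reference


namespaces, fmt = get_format("application/geo+json")                # default call unpacks a tuple
ref_only = get_format("application/geo+json", make_reference=True)  # single string reference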
58 changes: 36 additions & 22 deletions weaver/processes/builtin/collection_processor.py
@@ -40,9 +40,12 @@
from weaver.wps_restapi import swagger_definitions as sd # isort:skip # noqa: E402

if TYPE_CHECKING:
from typing import List

from pystac import Asset

from weaver.typedefs import (
CWL_IO_FileValue,
CWL_IO_ValueMap,
JSON,
JobValueCollection,
@@ -67,7 +70,7 @@


def process_collection(collection_input, input_definition, output_dir, logger=LOGGER):
# type: (JobValueCollection, ProcessInputOutputItem, Path, LoggerHandler) -> CWL_IO_ValueMap
# type: (JobValueCollection, ProcessInputOutputItem, Path, LoggerHandler) -> List[CWL_IO_FileValue]
"""
Processor of a :term:`Collection`.
@@ -86,12 +89,14 @@ def process_collection(collection_input, input_definition, output_dir, logger=LOGGER):
input_id = get_any_id(input_definition)
logger.log(
logging.INFO,
"Process [{}] Got arguments: collection_input=[{}] output_dir=[{}], for input=[{}]",
"Process [{}] Got arguments: collection_input={} output_dir=[{}], for input=[{}]",
PACKAGE_NAME,
Lazify(lambda: repr_json(collection_input, indent=2)),
output_dir,
input_id,
)
os.makedirs(output_dir, exist_ok=True)

col_input = sd.ExecuteCollectionInput().deserialize(collection_input) # type: JobValueCollection
col_args = dict(col_input)
col_href = col_href_valid = col_args.pop("collection")
@@ -116,7 +121,8 @@ def process_collection(collection_input, input_definition, output_dir, logger=LOGGER):
if col_media_type and not isinstance(col_media_type, list):
col_media_type = [col_media_type]

api_url, col_id = col_href.rsplit("/collections/", 1)
col_parts = col_href.rsplit("/collections/", 1)
api_url, col_id = col_parts if len(col_parts) == 2 else (None, col_parts[0])

# convert all parameters to their corresponding name of the query utility
# all OWSLib utilities use (**kwargs) allowing additional parameters that will be ignored
@@ -129,24 +135,26 @@ def process_collection(collection_input, input_definition, output_dir, logger=LOGGER):
logger.log(logging.INFO, "Attempting resolution of collection [{}] as format [{}]", col_href, col_fmt)
resolved_files = []
if col_fmt == ExecuteCollectionFormat.GEOJSON:
# static GeoJSON FeatureCollection document
col_resp = request_extra(
"GET",
col_href,
queries=col_args,
headers={"Accept": f"{ContentType.APP_GEOJSON},{ContentType.APP_JSON}"},
timeout=10,
timeout=col_args["timeout"],
retries=3,
only_server_errors=False,
)
if not (col_resp.status_code == 200 and "features" in col_resp.json):
col_json = col_resp.json()
if not (col_resp.status_code == 200 and "features" in col_json):
raise ValueError(f"Could not parse [{col_href}] as a GeoJSON FeatureCollection.")

for i, feat in enumerate(col_resp.json["features"]):
for i, feat in enumerate(col_json["features"]):
path = os.path.join(output_dir, f"feature-{i}.geojson")
with open(path, mode="w", encoding="utf-8") as file:
json.dump(feat, file)
file_typ = get_cwl_file_format(ContentType.APP_GEOJSON)
file_obj = {"class": "File", "path": f"file://{path}", "format": file_typ}
_, file_fmt = get_cwl_file_format(ContentType.APP_GEOJSON)
file_obj = {"class": "File", "path": f"file://{path}", "format": file_fmt}
resolved_files.append(file_obj)

elif col_fmt == ExecuteCollectionFormat.STAC:
@@ -166,8 +174,8 @@ def process_collection(collection_input, input_definition, output_dir, logger=LOGGER):
for item in search.items():
for ctype in col_media_type:
for _, asset in item.get_assets(media_type=ctype): # type: (..., Asset)
file_typ = get_cwl_file_format(ctype)
file_obj = {"class": "File", "path": asset.href, "format": file_typ}
_, file_fmt = get_cwl_file_format(ctype)
file_obj = {"class": "File", "path": asset.href, "format": file_fmt}
resolved_files.append(file_obj)

elif col_fmt == ExecuteCollectionFormat.OGC_FEATURES:
@@ -184,15 +192,15 @@ def process_collection(collection_input, input_definition, output_dir, logger=LOGGER):
if "assets" in feat and col_media_type != [ContentType.APP_GEOJSON]:
for name, asset in feat["assets"].items(): # type: (str, JSON)
if asset["type"] in col_media_type:
file_typ = get_cwl_file_format(asset["type"])
file_obj = {"class": "File", "path": asset["href"], "format": file_typ}
_, file_fmt = get_cwl_file_format(asset["type"])
file_obj = {"class": "File", "path": asset["href"], "format": file_fmt}
resolved_files.append(file_obj)
else:
path = os.path.join(output_dir, f"feature-{i}.geojson")
with open(path, mode="w", encoding="utf-8") as file:
json.dump(feat, file)
file_typ = get_cwl_file_format(ContentType.APP_GEOJSON)
file_obj = {"class": "File", "path": f"file://{path}", "format": file_typ}
_, file_fmt = get_cwl_file_format(ContentType.APP_GEOJSON)
file_obj = {"class": "File", "path": f"file://{path}", "format": file_fmt}
resolved_files.append(file_obj)

elif col_fmt == ExecuteCollectionFormat.OGC_COVERAGE:
@@ -206,8 +214,8 @@ def process_collection(collection_input, input_definition, output_dir, logger=LOGGER):
with open(path, mode="wb") as file:
data = cast(io.BytesIO, cov.coverage(col_id)).getbuffer()
file.write(data) # type: ignore
file_typ = get_cwl_file_format(ctype)
file_obj = {"class": "File", "path": f"file://{path}", "format": file_typ}
_, file_fmt = get_cwl_file_format(ctype)
file_obj = {"class": "File", "path": f"file://{path}", "format": file_fmt}
resolved_files.append(file_obj)

elif col_fmt in ExecuteCollectionFormat.OGC_MAP:
@@ -221,8 +229,8 @@ def process_collection(collection_input, input_definition, output_dir, logger=LOGGER):
with open(path, mode="wb") as file:
data = cast(io.BytesIO, maps.map(col_id)).getbuffer()
file.write(data) # type: ignore
file_typ = get_cwl_file_format(ctype)
file_obj = {"class": "File", "path": f"file://{path}", "format": file_typ}
_, file_fmt = get_cwl_file_format(ctype)
file_obj = {"class": "File", "path": f"file://{path}", "format": file_fmt}
resolved_files.append(file_obj)

else:
@@ -231,14 +239,20 @@ def process_collection(collection_input, input_definition, output_dir, logger=LOGGER):
if not resolved_files:
raise ValueError(f"Could not extract any data or reference from input collection [{col_href}].")

outputs = {"outputs": resolved_files} # 'outputs' must match ID used in CWL definition
logger.log(logging.INFO, "Resolved collection [{}] returned {} references.", col_href, len(resolved_files))
logger.log(logging.INFO, "Resolved collection [{}] returned {} reference(s).", col_href, len(resolved_files))
logger.log(
logging.DEBUG,
"Resolved collection [{}] files:\n",
"Resolved collection [{}] files:\n{}",
col_href,
Lazify(lambda: repr_json(resolved_files, indent=2)),
)
return resolved_files


def process_cwl(collection_input, input_definition, output_dir):
# type: (JobValueCollection, ProcessInputOutputItem, Path) -> CWL_IO_ValueMap
files = process_collection(collection_input, input_definition, output_dir)
outputs = {"outputs": files} # 'outputs' must match ID used in CWL definition
with open(os.path.join(output_dir, OUTPUT_CWL_JSON), mode="w", encoding="utf-8") as file:
json.dump(outputs, file)
return outputs
@@ -271,7 +285,7 @@ def main(*args):
col_in = load_file(ns.c)
LOGGER.info("Process [%s] Loading process input definition '%s'.", PACKAGE_NAME, ns.p)
proc_in = load_file(ns.p)
sys.exit(process_collection(col_in, proc_in, ns.o) is not None)
sys.exit(process_cwl(col_in, proc_in, ns.o) is not None)


if __name__ == "__main__":
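For illustration, the staged document that process_cwl() writes would have roughly the following shape; the path and format reference below are hypothetical values, and only the 'outputs' key (matching the CWL output ID) is guaranteed by the code above.

import json

# hypothetical result of process_collection(), shown only to illustrate the written structure
resolved_files = [
    {
        "class": "File",
        "path": "file:///tmp/job-inputs/features/feature-0.geojson",
        "format": "iana:application/geo+json",  # whichever reference get_cwl_file_format() returned
    },
]
outputs = {"outputs": resolved_files}  # 'outputs' must match the ID used in the CWL definition
print(json.dumps(outputs, indent=2))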
37 changes: 26 additions & 11 deletions weaver/processes/execution.py
@@ -276,7 +276,9 @@ def execute_process(task, job_id, wps_url, headers=None):
job.save_log(errors=errors, logger=task_logger)
job = store.update_job(job)
finally:
# note:
# WARNING: important to clean before re-fetching, otherwise we lose internal references needing cleanup
job.cleanup()
# NOTE:
# don't update the progress and status here except for 'success' to preserve last error that was set
# it is more relevant to return the latest step that worked properly to understand where it failed
job = store.fetch_by_id(job.id)
@@ -285,7 +287,6 @@
job.status = Status.DISMISSED
task_success = map_status(job.status) not in JOB_STATUS_CATEGORIES[StatusCategory.FAILED]
collect_statistics(task_process, settings, job, rss_start)
job.cleanup()
if task_success:
job.progress = JobProgress.EXECUTE_MONITOR_END
job.status_message = f"Job {job.status}."
@@ -422,6 +423,8 @@ def parse_wps_input_complex(input_value, input_info):
schema_vars = ["reference", "$schema"]
input_field = get_any_value(input_info, key=True)
if isinstance(input_value, dict):
if input_field is None:
input_field = get_any_value(input_value, key=True)
ctype, c_enc = parse_wps_input_format(input_value, "type", search_variations=False)
if not ctype:
ctype, c_enc = parse_wps_input_format(input_value)
@@ -527,22 +530,34 @@ def parse_wps_inputs(wps_process, job):
input_values = [input_val]
input_details = [job_input] # metadata directly in definition, not nested per array value

# Pre-check collection for resolution of the referenced data.
# Because each collection input can result in either '1->1' or '1->N' file reference(s) mapping,
# resolution must be performed before iterating through input value/definitions to parse them.
# Whether the sink input receiving this data can map to 1 or N values is left for the execution to validate later.
resolved_inputs = []
for input_value, input_info in zip(input_values, input_details):
if isinstance(input_info, dict):
input_info["id"] = input_id
if isinstance(input_value, dict) and "collection" in input_value:
col_path = os.path.join(job.tmpdir, "inputs", input_id)
col_files = process_collection(input_value, input_info, col_path, logger=job)
resolved_inputs.extend([
(
{"href": col_file["path"], "type": map_cwl_media_type(col_file["format"])},
input_info
)
for col_file in col_files
])
else:
resolved_inputs.append((input_value, input_info))

for input_value, input_info in resolved_inputs:
# if already resolved, skip parsing
# it is important to omit explicitly provided 'null', otherwise the WPS object could be misleading
# for example, a 'ComplexData' with 'null' data will be auto-generated as text/plain with "null" string
if input_value is None:
input_data = None
else:
# pre-check collection for resolution of the referenced data
if isinstance(input_value, dict) and "collection" in input_value:
col_path = os.path.join(job.tmpdir, "inputs", input_id)
col_out = process_collection(input_value, input_info, col_path, logger=job)
input_value = [
{"href": col_file["path"], "type": map_cwl_media_type(col_file["format"])}
for col_file in col_out["outputs"]
]

# resolve according to relevant data type parsing
# value could be an embedded or remote definition
if input_id in complex_inputs:
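To make the '1 collection -> N file references' comment above concrete, a small sketch follows; map_cwl_media_type is stubbed and the resolved paths are invented for the example.

def map_cwl_media_type(cwl_format):
    # stub for illustration: the real helper maps a CWL format reference back to a media-type string
    return {"iana:application/geo+json": "application/geo+json"}.get(cwl_format, cwl_format)


# hypothetical output of process_collection() for a single 'features' collection input
col_files = [
    {"class": "File", "path": "file:///tmp/job/inputs/features/feature-0.geojson", "format": "iana:application/geo+json"},
    {"class": "File", "path": "file:///tmp/job/inputs/features/feature-1.geojson", "format": "iana:application/geo+json"},
]
resolved = [
    {"href": col_file["path"], "type": map_cwl_media_type(col_file["format"])}
    for col_file in col_files
]
assert len(resolved) == 2  # a single collection input expanded into two file references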
