From 750341cdaefd04ac8308e0b34ab8ca8fc0f3e75e Mon Sep 17 00:00:00 2001 From: John Chilton Date: Thu, 11 Jul 2024 14:14:16 -0400 Subject: [PATCH] Workflow tool state validation plumbing. --- .../dev/tool_state_state_classes.plantuml.svg | 44 ++++ .../dev/tool_state_state_classes.plantuml.txt | 22 ++ lib/galaxy/tool_util/parameters/__init__.py | 2 + lib/galaxy/tool_util/parameters/models.py | 12 +- .../tool_util/workflow_state/__init__.py | 23 ++ lib/galaxy/tool_util/workflow_state/_types.py | 28 +++ .../tool_util/workflow_state/convert.py | 131 +++++++++++ .../tool_util/workflow_state/validation.py | 21 ++ .../workflow_state/validation_format2.py | 142 ++++++++++++ .../workflow_state/validation_native.py | 211 ++++++++++++++++++ lib/galaxy/workflow/gx_validator.py | 63 ++++++ packages/tool_util/setup.cfg | 1 + .../test_workflow_state_helpers.py | 23 ++ .../invalid/extra_attribute.gxwf.yml | 15 ++ .../workflows/invalid/missing_link.gxwf.yml | 11 + .../invalid/wrong_link_name.gxwf.yml | 13 ++ .../test_workflow_state_conversion.py | 16 ++ .../workflows/test_workflow_validation.py | 80 +++++++ .../test_workflow_validation_helpers.py | 13 ++ .../unit/workflows/valid/simple_data.gxwf.yml | 13 ++ test/unit/workflows/valid/simple_int.gxwf.yml | 13 ++ 21 files changed, 895 insertions(+), 2 deletions(-) create mode 100644 lib/galaxy/tool_util/workflow_state/__init__.py create mode 100644 lib/galaxy/tool_util/workflow_state/_types.py create mode 100644 lib/galaxy/tool_util/workflow_state/convert.py create mode 100644 lib/galaxy/tool_util/workflow_state/validation.py create mode 100644 lib/galaxy/tool_util/workflow_state/validation_format2.py create mode 100644 lib/galaxy/tool_util/workflow_state/validation_native.py create mode 100644 lib/galaxy/workflow/gx_validator.py create mode 100644 test/unit/tool_util/workflow_state/test_workflow_state_helpers.py create mode 100644 test/unit/workflows/invalid/extra_attribute.gxwf.yml create mode 100644 
test/unit/workflows/invalid/missing_link.gxwf.yml create mode 100644 test/unit/workflows/invalid/wrong_link_name.gxwf.yml create mode 100644 test/unit/workflows/test_workflow_state_conversion.py create mode 100644 test/unit/workflows/test_workflow_validation.py create mode 100644 test/unit/workflows/test_workflow_validation_helpers.py create mode 100644 test/unit/workflows/valid/simple_data.gxwf.yml create mode 100644 test/unit/workflows/valid/simple_int.gxwf.yml diff --git a/doc/source/dev/tool_state_state_classes.plantuml.svg b/doc/source/dev/tool_state_state_classes.plantuml.svg index 28c7da1c9092..5874d39c6c1d 100644 --- a/doc/source/dev/tool_state_state_classes.plantuml.svg +++ b/doc/source/dev/tool_state_state_classes.plantuml.svg @@ -50,10 +50,31 @@ state_representation = "job_internal" } note bottom: Object references of the form \n{src: "hda", id: }.\n Mapping constructs expanded out.\n (Defaults are inserted?) +class TestCaseToolState { +state_representation = "test_case" ++ _to_base_model(input_models: ToolParameterBundle): Type[BaseModel] +} +note bottom: Object references of the form file name and URIs.\n Mapping constructs not allowed.\n + +class WorkflowStepToolState { +state_representation = "workflow_step" ++ _to_base_model(input_models: ToolParameterBundle): Type[BaseModel] +} +note bottom: Nearly everything optional except conditional discriminators.\n + +class WorkflowStepLinkedToolState { +state_representation = "workflow_step_linked" ++ _to_base_model(input_models: ToolParameterBundle): Type[BaseModel] +} +note bottom: Expect pre-process ``in`` dictionaries and bring in representation\n of links and defaults and validate them in model.\n + ToolState <|- - RequestToolState ToolState <|- - RequestInternalToolState ToolState <|- - RequestInternalDereferencedToolState ToolState <|- - JobInternalToolState +ToolState <|- - TestCaseToolState +ToolState <|- - WorkflowStepToolState +ToolState <|- - WorkflowStepLinkedToolState RequestToolState - 
RequestInternalToolState : decode > @@ -61,6 +82,7 @@ RequestInternalToolState - RequestInternalDereferencedToolState : dereference > RequestInternalDereferencedToolState o- - JobInternalToolState : expand > +WorkflowStepToolState o- - WorkflowStepLinkedToolState : preprocess_links_and_defaults > } @enduml @@ -150,10 +172,31 @@ state_representation = "job_internal" } note bottom: Object references of the form \n{src: "hda", id: }.\n Mapping constructs expanded out.\n (Defaults are inserted?) +class TestCaseToolState { +state_representation = "test_case" ++ _to_base_model(input_models: ToolParameterBundle): Type[BaseModel] +} +note bottom: Object references of the form file name and URIs.\n Mapping constructs not allowed.\n + +class WorkflowStepToolState { +state_representation = "workflow_step" ++ _to_base_model(input_models: ToolParameterBundle): Type[BaseModel] +} +note bottom: Nearly everything optional except conditional discriminators.\n + +class WorkflowStepLinkedToolState { +state_representation = "workflow_step_linked" ++ _to_base_model(input_models: ToolParameterBundle): Type[BaseModel] +} +note bottom: Expect pre-process ``in`` dictionaries and bring in representation\n of links and defaults and validate them in model.\n + ToolState <|- - RequestToolState ToolState <|- - RequestInternalToolState ToolState <|- - RequestInternalDereferencedToolState ToolState <|- - JobInternalToolState +ToolState <|- - TestCaseToolState +ToolState <|- - WorkflowStepToolState +ToolState <|- - WorkflowStepLinkedToolState RequestToolState - RequestInternalToolState : decode > @@ -161,6 +204,7 @@ RequestInternalToolState - RequestInternalDereferencedToolState : dereference > RequestInternalDereferencedToolState o- - JobInternalToolState : expand > +WorkflowStepToolState o- - WorkflowStepLinkedToolState : preprocess_links_and_defaults > } @enduml diff --git a/doc/source/dev/tool_state_state_classes.plantuml.txt b/doc/source/dev/tool_state_state_classes.plantuml.txt index 
0c2c82951eb3..d37366e96f11 100644 --- a/doc/source/dev/tool_state_state_classes.plantuml.txt +++ b/doc/source/dev/tool_state_state_classes.plantuml.txt @@ -35,10 +35,31 @@ state_representation = "job_internal" } note bottom: Object references of the form \n{src: "hda", id: }.\n Mapping constructs expanded out.\n (Defaults are inserted?) +class TestCaseToolState { +state_representation = "test_case" ++ _to_base_model(input_models: ToolParameterBundle): Type[BaseModel] +} +note bottom: Object references of the form file name and URIs.\n Mapping constructs not allowed.\n + +class WorkflowStepToolState { +state_representation = "workflow_step" ++ _to_base_model(input_models: ToolParameterBundle): Type[BaseModel] +} +note bottom: Nearly everything optional except conditional discriminators.\n + +class WorkflowStepLinkedToolState { +state_representation = "workflow_step_linked" ++ _to_base_model(input_models: ToolParameterBundle): Type[BaseModel] +} +note bottom: Expect pre-process ``in`` dictionaries and bring in representation\n of links and defaults and validate them in model.\n + ToolState <|-- RequestToolState ToolState <|-- RequestInternalToolState ToolState <|-- RequestInternalDereferencedToolState ToolState <|-- JobInternalToolState +ToolState <|-- TestCaseToolState +ToolState <|-- WorkflowStepToolState +ToolState <|-- WorkflowStepLinkedToolState RequestToolState - RequestInternalToolState : decode > @@ -46,5 +67,6 @@ RequestInternalToolState - RequestInternalDereferencedToolState : dereference > RequestInternalDereferencedToolState o-- JobInternalToolState : expand > +WorkflowStepToolState o-- WorkflowStepLinkedToolState : preprocess_links_and_defaults > } @enduml \ No newline at end of file diff --git a/lib/galaxy/tool_util/parameters/__init__.py b/lib/galaxy/tool_util/parameters/__init__.py index a3b6a8b37a7f..06f593a19f49 100644 --- a/lib/galaxy/tool_util/parameters/__init__.py +++ b/lib/galaxy/tool_util/parameters/__init__.py @@ -38,6 +38,7 @@ 
FloatParameterModel, HiddenParameterModel, IntegerParameterModel, + is_optional, LabelValue, RawStateDict, RepeatParameterModel, @@ -119,6 +120,7 @@ "RepeatParameterModel", "RawStateDict", "ValidationFunctionT", + "is_optional", "validate_against_model", "validate_internal_job", "validate_internal_landing_request", diff --git a/lib/galaxy/tool_util/parameters/models.py b/lib/galaxy/tool_util/parameters/models.py index 677b8f639e63..7a70c4d5c245 100644 --- a/lib/galaxy/tool_util/parameters/models.py +++ b/lib/galaxy/tool_util/parameters/models.py @@ -46,7 +46,7 @@ ) from ._types import ( cast_as_type, - is_optional, + is_optional as is_python_type_optional, list_type, optional, optional_if_needed, @@ -140,7 +140,7 @@ def dynamic_model_information_from_py_type( if requires_value is None: requires_value = param_model.request_requires_value initialize = ... if requires_value else None - py_type_is_optional = is_optional(py_type) + py_type_is_optional = is_python_type_optional(py_type) validators = validators or {} if not py_type_is_optional and not requires_value: validators["not_null"] = field_validator(name)(Validators.validate_not_none) @@ -1369,6 +1369,14 @@ class ToolParameterModel(RootModel): CwlUnionParameterModel.model_rebuild() +def is_optional(tool_parameter: ToolParameterT): + if isinstance(tool_parameter, BaseGalaxyToolParameterModelDefinition): + return tool_parameter.optional + else: + # refine CWL logic in CWL branch... + return False + + class ToolParameterBundle(Protocol): """An object having a dictionary of input models (i.e. a 'Tool')""" diff --git a/lib/galaxy/tool_util/workflow_state/__init__.py b/lib/galaxy/tool_util/workflow_state/__init__.py new file mode 100644 index 000000000000..9ae61481fe1c --- /dev/null +++ b/lib/galaxy/tool_util/workflow_state/__init__.py @@ -0,0 +1,23 @@ +"""Abstractions for reasoning about tool state within Galaxy workflows. 
+ +Like everything else in galaxy-tool-util, this package should be independent of +Galaxy's runtime. It is meant to provide utilities for reasoning about tool state +(largely building on the abstractions in galaxy.tool_util.parameters) within the +context of workflows. +""" + +from ._types import GetToolInfo +from .convert import ( + ConversionValidationFailure, + convert_state_to_format2, + Format2State, +) +from .validation import validate_workflow + +__all__ = ( + "ConversionValidationFailure", + "convert_state_to_format2", + "GetToolInfo", + "Format2State", + "validate_workflow", +) diff --git a/lib/galaxy/tool_util/workflow_state/_types.py b/lib/galaxy/tool_util/workflow_state/_types.py new file mode 100644 index 000000000000..cf91508624b5 --- /dev/null +++ b/lib/galaxy/tool_util/workflow_state/_types.py @@ -0,0 +1,28 @@ +from typing import ( + Any, + Dict, + Optional, + Union, +) + +from typing_extensions import ( + Literal, + Protocol, +) + +from galaxy.tool_util.models import ParsedTool + +NativeWorkflowDict = Dict[str, Any] +Format2WorkflowDict = Dict[str, Any] +AnyWorkflowDict = Union[NativeWorkflowDict, Format2WorkflowDict] +WorkflowFormat = Literal["gxformat2", "native"] +NativeStepDict = Dict[str, Any] +Format2StepDict = Dict[str, Any] +NativeToolStateDict = Dict[str, Any] +Format2StateDict = Dict[str, Any] + + +class GetToolInfo(Protocol): + """An interface for fetching tool information for steps in a workflow.""" + + def get_tool_info(self, tool_id: str, tool_version: Optional[str]) -> ParsedTool: ... 
diff --git a/lib/galaxy/tool_util/workflow_state/convert.py b/lib/galaxy/tool_util/workflow_state/convert.py new file mode 100644 index 000000000000..37f0c8702bb8 --- /dev/null +++ b/lib/galaxy/tool_util/workflow_state/convert.py @@ -0,0 +1,131 @@ +from typing import ( + Dict, + List, + Optional, +) + +from pydantic import ( + BaseModel, + Field, +) + +from galaxy.tool_util.models import ParsedTool +from galaxy.tool_util.parameters import ToolParameterT +from ._types import ( + Format2StateDict, + GetToolInfo, + NativeStepDict, +) +from .validation_format2 import validate_step_against +from .validation_native import ( + get_parsed_tool_for_native_step, + native_tool_state, + validate_native_step_against, +) + +Format2InputsDictT = Dict[str, str] + + +class Format2State(BaseModel): + state: Format2StateDict + inputs: Format2InputsDictT = Field(alias="in") + + +class ConversionValidationFailure(Exception): + pass + + +def convert_state_to_format2(native_step_dict: NativeStepDict, get_tool_info: GetToolInfo) -> Format2State: + parsed_tool = get_parsed_tool_for_native_step(native_step_dict, get_tool_info) + return convert_state_to_format2_using(native_step_dict, parsed_tool) + + +def convert_state_to_format2_using(native_step_dict: NativeStepDict, parsed_tool: Optional[ParsedTool]) -> Format2State: + """Create a "clean" gxformat2 workflow tool state from a native workflow step. + + gxformat2 does not know about tool specifications so it cannot reason about the native + tool state attribute and just copies it as is. This native state can be pretty ugly. The purpose + of this function is to build a cleaned up state to replace the gxformat2 copied native tool_state + with that is more readable and has stronger typing by using the tool's inputs to guide + the conversion (the parsed_tool parameter). + + This method validates both the native tool state and the resulting gxformat2 tool state + so that we can be more confident the conversion doesn't corrupt the workflow. 
If no meta + model to validate against is supplied or if either validation fails this method throws + ConversionValidationFailure to signal the caller to just use the native tool state as is + instead of trying to convert it to a cleaner gxformat2 tool state - under the assumption + it is better to have an "ugly" workflow than a corrupted one during conversion. + """ + if parsed_tool is None: + raise ConversionValidationFailure("Could not resolve tool inputs") + try: + validate_native_step_against(native_step_dict, parsed_tool) + except Exception: + raise ConversionValidationFailure( + "Failed to validate native step - not going to convert a tool state that isn't understood" + ) + result = _convert_valid_state_to_format2(native_step_dict, parsed_tool) + print(result.dict()) + try: + validate_step_against(result.dict(), parsed_tool) + except Exception: + raise ConversionValidationFailure( + "Failed to validate resulting cleaned step - not going to convert to an unvalidated tool state" + ) + return result + + +def _convert_valid_state_to_format2(native_step_dict: NativeStepDict, parsed_tool: ParsedTool) -> Format2State: + format2_state: Format2StateDict = {} + format2_in: Format2InputsDictT = {} + + root_tool_state = native_tool_state(native_step_dict) + tool_inputs = parsed_tool.inputs + _convert_state_level(native_step_dict, tool_inputs, root_tool_state, format2_state, format2_in) + return Format2State( + **{ + "state": format2_state, + "in": format2_in, + } + ) + + +def _convert_state_level( + step: NativeStepDict, + tool_inputs: List[ToolParameterT], + native_state: dict, + format2_state_at_level: dict, + format2_in: Format2InputsDictT, + prefix: Optional[str] = None, +) -> None: + prefix = prefix or "" + assert prefix is not None + for tool_input in tool_inputs: + _convert_state_at_level(step, tool_input, native_state, format2_state_at_level, format2_in, prefix) + + +def _convert_state_at_level( + step: NativeStepDict, + tool_input: ToolParameterT, + 
native_state_at_level: dict, + format2_state_at_level: dict, + format2_in: Format2InputsDictT, + prefix: str, +) -> None: + parameter_type = tool_input.parameter_type + parameter_name = tool_input.name + value = native_state_at_level.get(parameter_name, None) + state_path = parameter_name if not prefix else f"{prefix}|{parameter_name}" + if parameter_type == "gx_integer": + # check for runtime input + format2_value = int(value) + format2_state_at_level[parameter_name] = format2_value + elif parameter_type == "gx_data": + input_connections = step.get("input_connections", {}) + print(state_path) + print(input_connections) + if state_path in input_connections: + format2_in[state_path] = "placeholder" + else: + pass + # raise NotImplementedError(f"Unhandled parameter type {parameter_type}") diff --git a/lib/galaxy/tool_util/workflow_state/validation.py b/lib/galaxy/tool_util/workflow_state/validation.py new file mode 100644 index 000000000000..56a225fbb8f8 --- /dev/null +++ b/lib/galaxy/tool_util/workflow_state/validation.py @@ -0,0 +1,21 @@ +from ._types import ( + AnyWorkflowDict, + GetToolInfo, + WorkflowFormat, +) +from .validation_format2 import validate_workflow_format2 +from .validation_native import validate_workflow_native + + +def validate_workflow(workflow_dict: AnyWorkflowDict, get_tool_info: GetToolInfo): + if _format(workflow_dict) == "gxformat2": + validate_workflow_format2(workflow_dict, get_tool_info) + else: + validate_workflow_native(workflow_dict, get_tool_info) + + +def _format(workflow_dict: AnyWorkflowDict) -> WorkflowFormat: + if workflow_dict.get("a_galaxy_workflow") == "true": + return "native" + else: + return "gxformat2" diff --git a/lib/galaxy/tool_util/workflow_state/validation_format2.py b/lib/galaxy/tool_util/workflow_state/validation_format2.py new file mode 100644 index 000000000000..160576c3abe7 --- /dev/null +++ b/lib/galaxy/tool_util/workflow_state/validation_format2.py @@ -0,0 +1,142 @@ +from typing import ( + cast, + Optional, 
+) + +from gxformat2.model import ( + get_native_step_type, + pop_connect_from_step_dict, + setup_connected_values, + steps_as_list, +) + +from galaxy.tool_util.models import ParsedTool +from galaxy.tool_util.parameters import ( + ConditionalParameterModel, + ConditionalWhen, + flat_state_path, + keys_starting_with, + repeat_inputs_to_array, + RepeatParameterModel, + ToolParameterT, + validate_explicit_conditional_test_value, + WorkflowStepLinkedToolState, + WorkflowStepToolState, +) +from ._types import ( + Format2StepDict, + Format2WorkflowDict, + GetToolInfo, +) + + +def validate_workflow_format2(workflow_dict: Format2WorkflowDict, get_tool_info: GetToolInfo): + steps = steps_as_list(workflow_dict) + for step in steps: + validate_step_format2(step, get_tool_info) + + +def validate_step_format2(step_dict: Format2StepDict, get_tool_info: GetToolInfo): + step_type = get_native_step_type(step_dict) + if step_type != "tool": + return + tool_id = cast(str, step_dict.get("tool_id")) + tool_version: Optional[str] = cast(Optional[str], step_dict.get("tool_version")) + parsed_tool = get_tool_info.get_tool_info(tool_id, tool_version) + if parsed_tool is not None: + validate_step_against(step_dict, parsed_tool) + + +def validate_step_against(step_dict: Format2StepDict, parsed_tool: ParsedTool): + source_tool_state_model = WorkflowStepToolState.parameter_model_for(parsed_tool.inputs) + linked_tool_state_model = WorkflowStepLinkedToolState.parameter_model_for(parsed_tool.inputs) + contains_format2_state = "state" in step_dict + contains_native_state = "tool_state" in step_dict + if contains_format2_state: + assert source_tool_state_model + source_tool_state_model.model_validate(step_dict["state"]) + if not contains_native_state: + if not contains_format2_state: + step_dict["state"] = {} + # setup links and then validate against model... 
+ linked_step = merge_inputs(step_dict, parsed_tool) + linked_tool_state_model.model_validate(linked_step["state"]) + + +def merge_inputs(step_dict: Format2StepDict, parsed_tool: ParsedTool) -> Format2StepDict: + connect = pop_connect_from_step_dict(step_dict) + step_dict = setup_connected_values(step_dict, connect) + tool_inputs = parsed_tool.inputs + + state_at_level = step_dict["state"] + + for tool_input in tool_inputs: + _merge_into_state(connect, tool_input, state_at_level) + + for key in connect: + raise Exception(f"Failed to find parameter definition matching workflow linked key {key}") + return step_dict + + +def _merge_into_state( + connect, tool_input: ToolParameterT, state: dict, prefix: Optional[str] = None, branch_connect=None +): + if branch_connect is None: + branch_connect = connect + + name = tool_input.name + parameter_type = tool_input.parameter_type + state_path = flat_state_path(name, prefix) + if parameter_type == "gx_conditional": + conditional_state = state.get(name, {}) + if name not in state: + state[name] = conditional_state + + conditional = cast(ConditionalParameterModel, tool_input) + when: ConditionalWhen = _select_which_when(conditional, conditional_state) + test_parameter = conditional.test_parameter + conditional_connect = keys_starting_with(branch_connect, state_path) + _merge_into_state( + connect, test_parameter, conditional_state, prefix=state_path, branch_connect=conditional_connect + ) + for when_parameter in when.parameters: + _merge_into_state( + connect, when_parameter, conditional_state, prefix=state_path, branch_connect=conditional_connect + ) + elif parameter_type == "gx_repeat": + repeat_state_array = state.get(name, []) + repeat = cast(RepeatParameterModel, tool_input) + repeat_instance_connects = repeat_inputs_to_array(state_path, connect) + for i, repeat_instance_connect in enumerate(repeat_instance_connects): + while len(repeat_state_array) <= i: + repeat_state_array.append({}) + + repeat_instance_prefix = 
f"{state_path}_{i}" + for repeat_parameter in repeat.parameters: + _merge_into_state( + connect, + repeat_parameter, + repeat_state_array[i], + prefix=repeat_instance_prefix, + branch_connect=repeat_instance_connect, + ) + if repeat_state_array and name not in state: + state[name] = repeat_state_array + else: + if state_path in branch_connect: + state[name] = {"__class__": "ConnectedValue"} + del connect[state_path] + + +def _select_which_when(conditional: ConditionalParameterModel, state: dict) -> ConditionalWhen: + test_parameter = conditional.test_parameter + test_parameter_name = test_parameter.name + explicit_test_value = state.get(test_parameter_name) + test_value = validate_explicit_conditional_test_value(test_parameter_name, explicit_test_value) + for when in conditional.whens: + if test_value is None and when.is_default_when: + return when + elif test_value == when.discriminator: + return when + else: + raise Exception(f"Invalid conditional test value ({explicit_test_value}) for parameter ({test_parameter_name})") diff --git a/lib/galaxy/tool_util/workflow_state/validation_native.py b/lib/galaxy/tool_util/workflow_state/validation_native.py new file mode 100644 index 000000000000..ae08d15b4226 --- /dev/null +++ b/lib/galaxy/tool_util/workflow_state/validation_native.py @@ -0,0 +1,211 @@ +import json +from typing import ( + cast, + List, + Optional, +) + +from galaxy.tool_util.models import ParsedTool +from galaxy.tool_util.parameters import ( + ConditionalParameterModel, + ConditionalWhen, + flat_state_path, + is_optional, + repeat_inputs_to_array, + RepeatParameterModel, + SelectParameterModel, + ToolParameterT, + validate_explicit_conditional_test_value, +) +from ._types import ( + GetToolInfo, + NativeStepDict, + NativeToolStateDict, + NativeWorkflowDict, +) + + +def validate_native_step_against(step: NativeStepDict, parsed_tool: ParsedTool): + tool_state_jsonified = step.get("tool_state") + assert tool_state_jsonified + tool_state = 
json.loads(tool_state_jsonified) + tool_inputs = parsed_tool.inputs + + # merge input connections into ConnectedValues if they aren't already there... + _merge_inputs_into_state_dict(step, tool_inputs, tool_state) + + allowed_extra_keys = ["__page__", "__rerun_remap_job_id__"] + _validate_native_state_level(step, tool_inputs, tool_state, allowed_extra_keys=allowed_extra_keys) + + +def _validate_native_state_level( + step: NativeStepDict, tool_inputs: List[ToolParameterT], state_at_level: dict, allowed_extra_keys=None +): + if allowed_extra_keys is None: + allowed_extra_keys = [] + + keys_processed = set() + for tool_input in tool_inputs: + parameter_name = tool_input.name + keys_processed.add(parameter_name) + _validate_native_state_at_level(step, tool_input, state_at_level) + + for key in state_at_level.keys(): + if key not in keys_processed and key not in allowed_extra_keys: + raise Exception(f"Unknown key found {key}, failing state validation") + + +def _validate_native_state_at_level( + step: NativeStepDict, tool_input: ToolParameterT, state_at_level: dict, prefix: Optional[str] = None +): + parameter_type = tool_input.parameter_type + parameter_name = tool_input.name + value = state_at_level.get(parameter_name, None) + # state_path = parameter_name if prefix is None else f"{prefix}|{parameter_name}" + if parameter_type == "gx_integer": + try: + int(value) + except (ValueError, TypeError): + raise Exception(f"Invalid integer data found {value}") + elif parameter_type in ["gx_data", "gx_data_collection"]: + if isinstance(value, dict): + assert "__class__" in value + assert value["__class__"] in ["RuntimeValue", "ConnectedValue"] + else: + assert value in [None, "null"] + connections = native_connections_for(step, tool_input, prefix) + optional = is_optional(tool_input) + if not optional and not connections: + raise Exception( + "Disconnected non-optional input found, not attempting to validate non-practice workflow" + ) + + elif parameter_type == "gx_select": + select = 
cast(SelectParameterModel, tool_input) + options = select.options + if options is not None: + valid_values = [o.value for o in options] + if value not in valid_values: + raise Exception(f"Invalid select option found {value}") + elif parameter_type == "gx_conditional": + conditional_state = state_at_level.get(parameter_name, None) + conditional = cast(ConditionalParameterModel, tool_input) + when: ConditionalWhen = _select_which_when_native(conditional, conditional_state) + test_parameter = conditional.test_parameter + test_parameter_name = test_parameter.name + _validate_native_state_at_level(step, test_parameter, conditional_state) + _validate_native_state_level( + step, when.parameters, conditional_state, allowed_extra_keys=["__current_case__", test_parameter_name] + ) + else: + raise NotImplementedError(f"Unhandled parameter type ({parameter_type})") + + +def _select_which_when_native(conditional: ConditionalParameterModel, conditional_state: dict) -> ConditionalWhen: + test_parameter = conditional.test_parameter + test_parameter_name = test_parameter.name + explicit_test_value = conditional_state.get(test_parameter_name) + test_value = validate_explicit_conditional_test_value(test_parameter_name, explicit_test_value) + target_when = None + for when in conditional.whens: + # deal with native string -> bool issues in here... + if test_value is None and when.is_default_when: + target_when = when + elif test_value == when.discriminator: + target_when = when + + recorded_case = conditional_state.get("__current_case__") + if recorded_case is not None: + if not isinstance(recorded_case, int): + raise Exception(f"Unknown type of value for __current_case__ encountered {recorded_case}") + if recorded_case < 0 or recorded_case >= len(conditional.whens): + raise Exception(f"Unknown index value for __current_case__ encountered {recorded_case}") + recorded_when = conditional.whens[recorded_case] + + if target_when is None: + raise Exception("is this okay? 
I need more tests") + if target_when and recorded_case is not None and target_when != recorded_when: + raise Exception( + f"Problem parsing out tool state - inferred conflicting tool states for parameter {test_parameter_name}" + ) + return target_when + + +def _merge_inputs_into_state_dict( + step_dict: NativeStepDict, tool_inputs: List[ToolParameterT], state_at_level: dict, prefix: Optional[str] = None +): + for tool_input in tool_inputs: + _merge_into_state(step_dict, tool_input, state_at_level, prefix=prefix) + + +def _merge_into_state(step_dict: NativeStepDict, tool_input: ToolParameterT, state: dict, prefix: Optional[str] = None): + name = tool_input.name + parameter_type = tool_input.parameter_type + state_path = flat_state_path(name, prefix) + if parameter_type == "gx_conditional": + conditional_state = state.get(name, {}) + if name not in state: + state[name] = conditional_state + + conditional = cast(ConditionalParameterModel, tool_input) + when: ConditionalWhen = _select_which_when_native(conditional, conditional_state) + test_parameter = conditional.test_parameter + _merge_into_state(step_dict, test_parameter, conditional_state, prefix=state_path) + for when_parameter in when.parameters: + _merge_into_state(step_dict, when_parameter, conditional_state, prefix=state_path) + elif parameter_type == "gx_repeat": + repeat_state_array = state.get(name, []) + repeat = cast(RepeatParameterModel, tool_input) + repeat_instance_connects = repeat_inputs_to_array(state_path, step_dict.get("input_connections", {})) + for i, _ in enumerate(repeat_instance_connects): + while len(repeat_state_array) <= i: + repeat_state_array.append({}) + + repeat_instance_prefix = f"{state_path}_{i}" + for repeat_parameter in repeat.parameters: + _merge_into_state( + step_dict, + repeat_parameter, + repeat_state_array[i], + prefix=repeat_instance_prefix, + ) + if repeat_state_array and name not in state: + state[name] = repeat_state_array + else: + input_connections = step_dict.get("input_connections", {}) + if state_path in 
input_connections and state.get(name) is None: + state[name] = {"__class__": "ConnectedValue"} + + +def validate_step_native(step: NativeStepDict, get_tool_info: GetToolInfo): + parsed_tool = get_parsed_tool_for_native_step(step, get_tool_info) + if parsed_tool is not None: + validate_native_step_against(step, parsed_tool) + + +def get_parsed_tool_for_native_step(step: NativeStepDict, get_tool_info: GetToolInfo) -> Optional[ParsedTool]: + tool_id = cast(str, step.get("tool_id")) + if not tool_id: + return None + tool_version: Optional[str] = cast(Optional[str], step.get("tool_version")) + parsed_tool = get_tool_info.get_tool_info(tool_id, tool_version) + return parsed_tool + + +def validate_workflow_native(workflow_dict: NativeWorkflowDict, get_tool_info: GetToolInfo): + for step_def in workflow_dict["steps"].values(): + validate_step_native(step_def, get_tool_info) + + +def native_tool_state(step: NativeStepDict) -> NativeToolStateDict: + tool_state_jsonified = step.get("tool_state") + assert tool_state_jsonified + tool_state = json.loads(tool_state_jsonified) + return tool_state + + +def native_connections_for(step: NativeStepDict, parameter: ToolParameterT, prefix: Optional[str]): + parameter_name = parameter.name + state_path = parameter_name if prefix is None else f"{prefix}|{parameter_name}" + input_connections = step.get("input_connections", {}) + return input_connections.get(state_path) diff --git a/lib/galaxy/workflow/gx_validator.py b/lib/galaxy/workflow/gx_validator.py new file mode 100644 index 000000000000..37276457436a --- /dev/null +++ b/lib/galaxy/workflow/gx_validator.py @@ -0,0 +1,63 @@ +"""A validator for Galaxy workflows that is hooked up to Galaxy internals. + +The interface is designed to be usable from the tool shed for external tooling, +but for internal tooling - Galaxy should have its own tool available. 
+""" + +from typing import ( + Dict, + Optional, +) + +from galaxy.tool_util.models import ( + parse_tool, + ParsedTool, +) +from galaxy.tool_util.version import parse_version +from galaxy.tool_util.version_util import AnyVersionT +from galaxy.tool_util.workflow_state import ( + GetToolInfo, + validate_workflow as validate_workflow_generic, +) +from galaxy.tools.stock import stock_tool_sources + + +class GalaxyGetToolInfo(GetToolInfo): + stock_tools_by_version: Dict[str, Dict[AnyVersionT, ParsedTool]] + stock_tools_latest_version: Dict[str, AnyVersionT] + + def __init__(self): + # todo take in a toolbox in the future... + stock_tools: Dict[str, Dict[AnyVersionT, ParsedTool]] = {} + stock_tools_latest_version: Dict[str, AnyVersionT] = {} + for stock_tool in stock_tool_sources(): + id = stock_tool.parse_id() + version = stock_tool.parse_version() + if version is not None: + version_object = parse_version(version) + if id not in stock_tools: + stock_tools[id] = {} + if version_object is not None: + stock_tools_latest_version[id] = version_object + try: + stock_tools[id][version_object] = parse_tool(stock_tool) + except Exception: + pass + if version_object and version_object > stock_tools_latest_version[id]: + stock_tools_latest_version[id] = version_object + self.stock_tools = stock_tools + self.stock_tools_latest_version = stock_tools_latest_version + + def get_tool_info(self, tool_id: str, tool_version: Optional[str]) -> ParsedTool: + if tool_version is not None: + return self.stock_tools[tool_id][parse_version(tool_version)] + else: + latest_verison = self.stock_tools_latest_version[tool_id] + return self.stock_tools[tool_id][latest_verison] + + +GET_TOOL_INFO = GalaxyGetToolInfo() + + +def validate_workflow(as_dict): + return validate_workflow_generic(as_dict, get_tool_info=GET_TOOL_INFO) diff --git a/packages/tool_util/setup.cfg b/packages/tool_util/setup.cfg index eabca1a6e1b9..9db8b66ff45d 100644 --- a/packages/tool_util/setup.cfg +++ 
b/packages/tool_util/setup.cfg @@ -34,6 +34,7 @@ version = 24.2.dev0 include_package_data = True install_requires = galaxy-util>=22.1 + gxformat2 conda-package-streaming lxml!=4.2.2 MarkupSafe diff --git a/test/unit/tool_util/workflow_state/test_workflow_state_helpers.py b/test/unit/tool_util/workflow_state/test_workflow_state_helpers.py new file mode 100644 index 000000000000..e8544e22fbb2 --- /dev/null +++ b/test/unit/tool_util/workflow_state/test_workflow_state_helpers.py @@ -0,0 +1,23 @@
+from galaxy.tool_util.parameters import repeat_inputs_to_array
+
+
+def test_repeat_inputs_to_array():
+    """Check how flat ``<repeat>_<index>|<name>`` state keys are grouped.
+
+    A state without any key for the repeat yields a falsy result. A state
+    holding keys for indices 0 and 2 yields three entries, each containing
+    only the keys that belong to its own index.
+    """
+    state_without_repeat: dict = {"moo": "cow"}
+    assert not repeat_inputs_to_array("repeatfoo", state_without_repeat)
+
+    first_key = "repeatfoo_0|moocow"
+    third_key = "repeatfoo_2|moocow"
+    sparse_state: dict = {"moo": "cow", first_key: ["moo"], third_key: ["cow"]}
+    instances = repeat_inputs_to_array("repeatfoo", sparse_state)
+    assert len(instances) == 3
+    first, second, third = instances
+    assert first_key in first
+    assert first_key not in second
+    assert first_key not in third
+    assert third_key not in second
+    assert third_key in third
diff --git a/test/unit/workflows/invalid/extra_attribute.gxwf.yml b/test/unit/workflows/invalid/extra_attribute.gxwf.yml new file mode 100644 index 000000000000..6ae50799394c --- /dev/null +++ b/test/unit/workflows/invalid/extra_attribute.gxwf.yml @@ -0,0 +1,15 @@ +class: GalaxyWorkflow +inputs: + input: + type: int +outputs: + output: + outputSource: the_step/output +steps: + the_step: + tool_id: gx_int + tool_version: "1.0.0" + state: + parameter2: 6 + in: + parameter: input diff --git a/test/unit/workflows/invalid/missing_link.gxwf.yml b/test/unit/workflows/invalid/missing_link.gxwf.yml new file mode 100644 index 000000000000..526b40f6f502 --- /dev/null +++ b/test/unit/workflows/invalid/missing_link.gxwf.yml @@ -0,0 +1,11 @@ +class: GalaxyWorkflow +inputs: + input: + type: data +outputs: + output: + outputSource: the_step/output +steps: + the_step: + tool_id: gx_data +
tool_version: "1.0.0" diff --git a/test/unit/workflows/invalid/wrong_link_name.gxwf.yml b/test/unit/workflows/invalid/wrong_link_name.gxwf.yml new file mode 100644 index 000000000000..f0e0e8d12004 --- /dev/null +++ b/test/unit/workflows/invalid/wrong_link_name.gxwf.yml @@ -0,0 +1,13 @@ +class: GalaxyWorkflow +inputs: + input: + type: int +outputs: + output: + outputSource: the_step/output +steps: + the_step: + tool_id: gx_int + tool_version: "1.0.0" + in: + parameterx: input diff --git a/test/unit/workflows/test_workflow_state_conversion.py b/test/unit/workflows/test_workflow_state_conversion.py new file mode 100644 index 000000000000..e73b73b6aa55 --- /dev/null +++ b/test/unit/workflows/test_workflow_state_conversion.py @@ -0,0 +1,16 @@
+from galaxy.tool_util.workflow_state.convert import (
+    convert_state_to_format2,
+    Format2State,
+)
+from galaxy.workflow.gx_validator import GET_TOOL_INFO
+from .test_workflow_validation import base_package_workflow_as_dict
+
+
+def convert_state(native_step_dict: dict) -> Format2State:
+    """Convert a native step dictionary to Format2 state using the stock-tool lookup."""
+    return convert_state_to_format2(native_step_dict, GET_TOOL_INFO)
+
+
+def test_simple_convert():
+    """Smoke test: converting step "2" of ``test_workflow_1.ga`` must not raise."""
+    workflow_dict = base_package_workflow_as_dict("test_workflow_1.ga")
+    cat_step = workflow_dict["steps"]["2"]
+    convert_state(cat_step)
diff --git a/test/unit/workflows/test_workflow_validation.py b/test/unit/workflows/test_workflow_validation.py new file mode 100644 index 000000000000..33c93ede6df5 --- /dev/null +++ b/test/unit/workflows/test_workflow_validation.py @@ -0,0 +1,80 @@
+import os
+from typing import Optional
+
+from gxformat2.yaml import ordered_load
+
+from galaxy.util import galaxy_directory
+from galaxy.workflow.gx_validator import validate_workflow
+
+TEST_WORKFLOW_DIRECTORY = os.path.join(galaxy_directory(), "lib", "galaxy_test", "workflow")
+TEST_BASE_DATA_DIRECTORY = os.path.join(galaxy_directory(), "lib", "galaxy_test", "base", "data")
+SCRIPT_DIRECTORY = os.path.abspath(os.path.dirname(__file__))
+
+
+def test_validate_simple_functional_test_case_workflow():
+    """Framework test workflows are expected to validate without raising."""
+    validate_workflow(framework_test_workflow_as_dict("multiple_versions"))
+    validate_workflow(framework_test_workflow_as_dict("zip_collection"))
+    validate_workflow(framework_test_workflow_as_dict("empty_collection_sort"))
+    validate_workflow(framework_test_workflow_as_dict("flatten_collection"))
+    validate_workflow(framework_test_workflow_as_dict("flatten_collection_over_execution"))
+
+
+def test_validate_native_workflows():
+    """Native (``.ga``) workflows from the base test data are expected to validate."""
+    validate_workflow(base_package_workflow_as_dict("test_workflow_two_random_lines.ga"))
+    # disconnected input...
+    # validate_workflow(base_package_workflow_as_dict("test_workflow_topoambigouity.ga"))
+    # double nested JSON?
+    # validate_workflow(base_package_workflow_as_dict("test_Workflow_map_reduce_pause.ga"))
+    # handle subworkflows...
+    # validate_workflow(base_package_workflow_as_dict("test_subworkflow_with_integer_input.ga"))
+    # handle gx_text....
+    # validate_workflow(base_package_workflow_as_dict("test_workflow_batch.ga"))
+
+
+def test_validate_unit_test_workflows():
+    """The ``valid/`` fixtures next to this file are expected to validate."""
+    validate_workflow(unit_test_workflow_as_dict("valid/simple_int"))
+    validate_workflow(unit_test_workflow_as_dict("valid/simple_data"))
+
+
+def test_invalidate_with_extra_attribute():
+    """An unknown state key must be rejected and named in the error."""
+    e = _assert_validation_failure("invalid/extra_attribute")
+    assert "parameter2" in str(e)
+
+
+def test_invalidate_with_wrong_link_name():
+    """A link to an unknown parameter must be rejected and named in the error."""
+    e = _assert_validation_failure("invalid/wrong_link_name")
+    assert "parameterx" in str(e)
+
+
+def test_invalidate_with_missing_link():
+    """A step missing its link must be rejected with a ``missing`` error for ``parameter``."""
+    e = _assert_validation_failure("invalid/missing_link")
+    assert "parameter" in str(e)
+    assert "type=missing" in str(e)
+
+
+def _assert_validation_failure(workflow_name: str) -> Exception:
+    """Validate the named unit-test workflow and return the exception it raises.
+
+    Fails the calling test when validation unexpectedly succeeds.
+    """
+    as_dict = unit_test_workflow_as_dict(workflow_name)
+    exc: Optional[Exception] = None
+    try:
+        validate_workflow(as_dict)
+    except Exception as e:
+        exc = e
+    assert exc, f"Target workflow ({workflow_name}) did not fail validation as expected."
+    return exc
+
+
+def base_package_workflow_as_dict(file_name: str) -> dict:
+    """Load a workflow file from ``TEST_BASE_DATA_DIRECTORY``."""
+    return _load(os.path.join(TEST_BASE_DATA_DIRECTORY, file_name))
+
+
+def unit_test_workflow_as_dict(workflow_name: str) -> dict:
+    """Load ``<workflow_name>.gxwf.yml`` relative to this test module's directory."""
+    return _load(os.path.join(SCRIPT_DIRECTORY, f"{workflow_name}.gxwf.yml"))
+
+
+def framework_test_workflow_as_dict(workflow_name: str) -> dict:
+    """Load ``<workflow_name>.gxwf.yml`` from ``TEST_WORKFLOW_DIRECTORY``."""
+    return _load(os.path.join(TEST_WORKFLOW_DIRECTORY, f"{workflow_name}.gxwf.yml"))
+
+
+def _load(path: str) -> dict:
+    """Read ``path`` and parse it with ``ordered_load``."""
+    with open(path) as f:
+        return ordered_load(f)
diff --git a/test/unit/workflows/test_workflow_validation_helpers.py b/test/unit/workflows/test_workflow_validation_helpers.py new file mode 100644 index 000000000000..af74e2b32a5f --- /dev/null +++ b/test/unit/workflows/test_workflow_validation_helpers.py @@ -0,0 +1,13 @@
+from galaxy.workflow.gx_validator import GET_TOOL_INFO
+
+
+def test_get_tool():
+    """``cat1`` resolves to version 1.0.0 both explicitly and as the latest version."""
+    parsed_tool = GET_TOOL_INFO.get_tool_info("cat1", "1.0.0")
+    assert parsed_tool
+    assert parsed_tool.id == "cat1"
+    assert parsed_tool.version == "1.0.0"
+
+    parsed_tool = GET_TOOL_INFO.get_tool_info("cat1", None)
+    assert parsed_tool
+    assert parsed_tool.id == "cat1"
+    assert parsed_tool.version == "1.0.0"
diff --git a/test/unit/workflows/valid/simple_data.gxwf.yml b/test/unit/workflows/valid/simple_data.gxwf.yml new file mode 100644 index 000000000000..44f0a90f3dd9 --- /dev/null +++ b/test/unit/workflows/valid/simple_data.gxwf.yml @@ -0,0 +1,13 @@ +class: GalaxyWorkflow +inputs: + input: + type: data +outputs: + output: + outputSource: the_step/output +steps: + the_step: + tool_id: gx_data + tool_version: "1.0.0" + in: + parameter: input diff --git a/test/unit/workflows/valid/simple_int.gxwf.yml b/test/unit/workflows/valid/simple_int.gxwf.yml new file mode 100644 index 000000000000..d7c53f78d0a6 --- /dev/null +++ b/test/unit/workflows/valid/simple_int.gxwf.yml @@ -0,0 +1,13 @@ +class: GalaxyWorkflow +inputs: + input: + type: int +outputs: + output: + outputSource: the_step/output +steps: + the_step: +
tool_id: gx_int + tool_version: "1.0.0" + in: + parameter: input