Skip to content

Commit

Permalink
Merge pull request #18743 from jmchilton/parameter_models_rework_test…
Browse files Browse the repository at this point in the history
…_case_xml

Improvements to parameter models for test case inputs
  • Loading branch information
jmchilton committed Sep 19, 2024
2 parents 844beb6 + fd4b182 commit 0863221
Show file tree
Hide file tree
Showing 17 changed files with 427 additions and 92 deletions.
2 changes: 1 addition & 1 deletion lib/galaxy/tool_util/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def parse_tool(tool_source: ToolSource) -> ParsedTool:
version = tool_source.parse_version()
name = tool_source.parse_name()
description = tool_source.parse_description()
inputs = input_models_for_tool_source(tool_source).input_models
inputs = input_models_for_tool_source(tool_source).parameters
outputs = from_tool_source(tool_source)
citations = tool_source.parse_citations()
license = tool_source.parse_license()
Expand Down
2 changes: 2 additions & 0 deletions lib/galaxy/tool_util/parameters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
repeat_inputs_to_array,
validate_explicit_conditional_test_value,
visit_input_values,
VISITOR_NO_REPLACEMENT,
)

__all__ = (
Expand Down Expand Up @@ -116,6 +117,7 @@
"keys_starting_with",
"visit_input_values",
"repeat_inputs_to_array",
"VISITOR_NO_REPLACEMENT",
"decode",
"encode",
"WorkflowStepToolState",
Expand Down
26 changes: 24 additions & 2 deletions lib/galaxy/tool_util/parameters/case.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
from packaging.version import Version

from galaxy.tool_util.parser.interface import (
TestCollectionDef,
ToolSource,
ToolSourceTest,
ToolSourceTestInput,
ToolSourceTestInputs,
xml_data_input_to_json,
XmlTestCollectionDefDict,
)
from galaxy.util import asbool
from .factory import input_models_for_tool_source
Expand All @@ -25,6 +28,7 @@
ConditionalWhen,
DataCollectionParameterModel,
DataColumnParameterModel,
DataParameterModel,
FloatParameterModel,
IntegerParameterModel,
RepeatParameterModel,
Expand Down Expand Up @@ -249,8 +253,26 @@ def _merge_into_state(
else:
test_input = _input_for(state_path, inputs)
if test_input is not None:
input_value: Any
if isinstance(tool_input, (DataCollectionParameterModel,)):
input_value = test_input.get("attributes", {}).get("collection")
input_value = TestCollectionDef.from_dict(
cast(XmlTestCollectionDefDict, test_input.get("attributes", {}).get("collection"))
).test_format_to_dict()
elif isinstance(tool_input, (DataParameterModel,)):
data_tool_input = cast(DataParameterModel, tool_input)
if data_tool_input.multiple:
value = test_input["value"]
input_value_list = []
if value:
test_input_values = cast(str, value).split(",")
for test_input_value in test_input_values:
instance_test_input = test_input.copy()
instance_test_input["value"] = test_input_value
input_value = xml_data_input_to_json(test_input)
input_value_list.append(input_value)
input_value = input_value_list
else:
input_value = xml_data_input_to_json(test_input)
else:
input_value = test_input["value"]
input_value = legacy_from_string(tool_input, input_value, warnings, profile)
Expand Down Expand Up @@ -299,6 +321,6 @@ def validate_test_cases_for_tool_source(
test_cases: List[ToolSourceTest] = tool_source.parse_tests_to_dict()["tests"]
results_by_test: List[TestCaseStateValidationResult] = []
for test_case in test_cases:
validation_result = test_case_validation(test_case, tool_parameter_bundle.input_models, profile)
validation_result = test_case_validation(test_case, tool_parameter_bundle.parameters, profile)
results_by_test.append(validation_result)
return results_by_test
4 changes: 2 additions & 2 deletions lib/galaxy/tool_util/parameters/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def _from_input_source_cwl(input_source: CwlInputSource) -> ToolParameterT:


def input_models_from_json(json: List[Dict[str, Any]]) -> ToolParameterBundle:
return ToolParameterBundleModel(input_models=json)
return ToolParameterBundleModel(parameters=json)


def tool_parameter_bundle_from_json(json: Dict[str, Any]) -> ToolParameterBundleModel:
Expand All @@ -328,7 +328,7 @@ def tool_parameter_bundle_from_json(json: Dict[str, Any]) -> ToolParameterBundle

def input_models_for_tool_source(tool_source: ToolSource) -> ToolParameterBundleModel:
pages = tool_source.parse_input_pages()
return ToolParameterBundleModel(input_models=input_models_for_pages(pages))
return ToolParameterBundleModel(parameters=input_models_for_pages(pages))


def input_models_for_pages(pages: PagesSource) -> List[ToolParameterT]:
Expand Down
32 changes: 15 additions & 17 deletions lib/galaxy/tool_util/parameters/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@
from galaxy.exceptions import RequestParameterInvalidException
from galaxy.tool_util.parser.interface import (
DrillDownOptionsDict,
TestCollectionDefDict,
JsonTestCollectionDefDict,
JsonTestDatasetDefDict,
)
from ._types import (
cast_as_type,
Expand Down Expand Up @@ -312,9 +313,9 @@ def py_type_internal(self) -> Type:
def py_type_test_case(self) -> Type:
base_model: Type
if self.multiple:
base_model = str
base_model = list_type(JsonTestDatasetDefDict)
else:
base_model = str
base_model = JsonTestDatasetDefDict
return optional_if_needed(base_model, self.optional)

def pydantic_template(self, state_representation: StateRepresentationT) -> DynamicModelInformation:
Expand Down Expand Up @@ -372,7 +373,7 @@ def pydantic_template(self, state_representation: StateRepresentationT) -> Dynam
elif state_representation == "workflow_step_linked":
return dynamic_model_information_from_py_type(self, ConnectedValue)
elif state_representation == "test_case_xml":
return dynamic_model_information_from_py_type(self, TestCollectionDefDict)
return dynamic_model_information_from_py_type(self, JsonTestCollectionDefDict)
else:
raise NotImplementedError(
f"Have not implemented data collection parameter models for state representation {state_representation}"
Expand Down Expand Up @@ -1164,12 +1165,11 @@ class ToolParameterModel(RootModel):
class ToolParameterBundle(Protocol):
"""An object having a dictionary of input models (i.e. a 'Tool')"""

# TODO: rename to parameters to align with ConditionalWhen and Repeat.
input_models: List[ToolParameterT]
parameters: List[ToolParameterT]


class ToolParameterBundleModel(BaseModel):
input_models: List[ToolParameterT]
parameters: List[ToolParameterT]


def to_simple_model(input_parameter: Union[ToolParameterModel, ToolParameterT]) -> ToolParameterT:
Expand All @@ -1180,10 +1180,8 @@ def to_simple_model(input_parameter: Union[ToolParameterModel, ToolParameterT])
return cast(ToolParameterT, input_parameter)


def simple_input_models(
input_models: Union[List[ToolParameterModel], List[ToolParameterT]]
) -> Iterable[ToolParameterT]:
return [to_simple_model(m) for m in input_models]
def simple_input_models(parameters: Union[List[ToolParameterModel], List[ToolParameterT]]) -> Iterable[ToolParameterT]:
return [to_simple_model(m) for m in parameters]


def create_model_strict(*args, **kwd) -> Type[BaseModel]:
Expand All @@ -1194,27 +1192,27 @@ def create_model_strict(*args, **kwd) -> Type[BaseModel]:


def create_request_model(tool: ToolParameterBundle, name: str = "DynamicModelForTool") -> Type[BaseModel]:
return create_field_model(tool.input_models, name, "request")
return create_field_model(tool.parameters, name, "request")


def create_request_internal_model(tool: ToolParameterBundle, name: str = "DynamicModelForTool") -> Type[BaseModel]:
return create_field_model(tool.input_models, name, "request_internal")
return create_field_model(tool.parameters, name, "request_internal")


def create_job_internal_model(tool: ToolParameterBundle, name: str = "DynamicModelForTool") -> Type[BaseModel]:
return create_field_model(tool.input_models, name, "job_internal")
return create_field_model(tool.parameters, name, "job_internal")


def create_test_case_model(tool: ToolParameterBundle, name: str = "DynamicModelForTool") -> Type[BaseModel]:
return create_field_model(tool.input_models, name, "test_case_xml")
return create_field_model(tool.parameters, name, "test_case_xml")


def create_workflow_step_model(tool: ToolParameterBundle, name: str = "DynamicModelForTool") -> Type[BaseModel]:
return create_field_model(tool.input_models, name, "workflow_step")
return create_field_model(tool.parameters, name, "workflow_step")


def create_workflow_step_linked_model(tool: ToolParameterBundle, name: str = "DynamicModelForTool") -> Type[BaseModel]:
return create_field_model(tool.input_models, name, "workflow_step_linked")
return create_field_model(tool.parameters, name, "workflow_step_linked")


def create_field_model(
Expand Down
38 changes: 19 additions & 19 deletions lib/galaxy/tool_util/parameters/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ def __init__(self, input_state: Dict[str, Any]):
def _validate(self, pydantic_model: Type[BaseModel]) -> None:
validate_against_model(pydantic_model, self.input_state)

def validate(self, input_models: HasToolParameters) -> None:
base_model = self.parameter_model_for(input_models)
def validate(self, parameters: HasToolParameters) -> None:
base_model = self.parameter_model_for(parameters)
if base_model is None:
raise NotImplementedError(
f"Validating tool state against state representation {self.state_representation} is not implemented."
Expand All @@ -53,64 +53,64 @@ def state_representation(self) -> StateRepresentationT:
"""Get state representation of the inputs."""

@classmethod
def parameter_model_for(cls, input_models: HasToolParameters) -> Type[BaseModel]:
def parameter_model_for(cls, parameters: HasToolParameters) -> Type[BaseModel]:
bundle: ToolParameterBundle
if isinstance(input_models, list):
bundle = ToolParameterBundleModel(input_models=input_models)
if isinstance(parameters, list):
bundle = ToolParameterBundleModel(parameters=parameters)
else:
bundle = input_models
bundle = parameters
return cls._parameter_model_for(bundle)

@classmethod
@abstractmethod
def _parameter_model_for(cls, input_models: ToolParameterBundle) -> Type[BaseModel]:
def _parameter_model_for(cls, parameters: ToolParameterBundle) -> Type[BaseModel]:
"""Return a model type for this tool state kind."""


class RequestToolState(ToolState):
state_representation: Literal["request"] = "request"

@classmethod
def _parameter_model_for(cls, input_models: ToolParameterBundle) -> Type[BaseModel]:
return create_request_model(input_models)
def _parameter_model_for(cls, parameters: ToolParameterBundle) -> Type[BaseModel]:
return create_request_model(parameters)


class RequestInternalToolState(ToolState):
state_representation: Literal["request_internal"] = "request_internal"

@classmethod
def _parameter_model_for(cls, input_models: ToolParameterBundle) -> Type[BaseModel]:
return create_request_internal_model(input_models)
def _parameter_model_for(cls, parameters: ToolParameterBundle) -> Type[BaseModel]:
return create_request_internal_model(parameters)


class JobInternalToolState(ToolState):
state_representation: Literal["job_internal"] = "job_internal"

@classmethod
def _parameter_model_for(cls, input_models: ToolParameterBundle) -> Type[BaseModel]:
return create_job_internal_model(input_models)
def _parameter_model_for(cls, parameters: ToolParameterBundle) -> Type[BaseModel]:
return create_job_internal_model(parameters)


class TestCaseToolState(ToolState):
state_representation: Literal["test_case_xml"] = "test_case_xml"

@classmethod
def _parameter_model_for(cls, input_models: ToolParameterBundle) -> Type[BaseModel]:
def _parameter_model_for(cls, parameters: ToolParameterBundle) -> Type[BaseModel]:
# implement a test case model...
return create_test_case_model(input_models)
return create_test_case_model(parameters)


class WorkflowStepToolState(ToolState):
state_representation: Literal["workflow_step"] = "workflow_step"

@classmethod
def _parameter_model_for(cls, input_models: ToolParameterBundle) -> Type[BaseModel]:
return create_workflow_step_model(input_models)
def _parameter_model_for(cls, parameters: ToolParameterBundle) -> Type[BaseModel]:
return create_workflow_step_model(parameters)


class WorkflowStepLinkedToolState(ToolState):
state_representation: Literal["workflow_step_linked"] = "workflow_step_linked"

@classmethod
def _parameter_model_for(cls, input_models: ToolParameterBundle) -> Type[BaseModel]:
return create_workflow_step_linked_model(input_models)
def _parameter_model_for(cls, parameters: ToolParameterBundle) -> Type[BaseModel]:
return create_workflow_step_linked_model(parameters)
2 changes: 1 addition & 1 deletion lib/galaxy/tool_util/parameters/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def visit_input_values(
no_replacement_value=VISITOR_NO_REPLACEMENT,
) -> Dict[str, Any]:
return _visit_input_values(
simple_input_models(input_models.input_models),
simple_input_models(input_models.parameters),
tool_state.input_state,
callback=callback,
no_replacement_value=no_replacement_value,
Expand Down
Loading

0 comments on commit 0863221

Please sign in to comment.