diff --git a/src/lsst/cmservice/cli/options.py b/src/lsst/cmservice/cli/options.py index bf573a2b1..7e6c1d0d7 100644 --- a/src/lsst/cmservice/cli/options.py +++ b/src/lsst/cmservice/cli/options.py @@ -25,6 +25,82 @@ ] +class DictParamType(click.ParamType): + """Represents the dictionary type of a CLI parameter. + + Validates and converts values from the command line string or Python into + a Python dict. + - All key-value pairs must be separated by one semicolon. + - Key and value must be separated by one equal sign. + - Converts sequences separeted by dots into a list: list value items + must be separated by commas. + - Converts numbers to int. + + Usage: + >>> @click.option("--param", default=None, type=DictParamType()) + ... def command(param): + ... ... + + CLI: command --param='page=1; name=Items; rules=1, 2, three; extra=A,;' + + Example: + + >>> param_value = 'page=1; name=Items; rules=1, 2, three; extra=A,;' + >>> DictParamType().convert(param_value, None, None) + {'page': 1, 'name': 'Items', 'rules': [1, 2, 'three'], 'extra': ['A']}` + + """ + + name = "dictionary" + + def convert(self, cli_value, param, ctx): + """Converts CLI value to the dictionary structure. + + Args: + cli_value (Any): The value to convert. + param (click.Parameter | None): The parameter that is using this + type to convert its value. + ctx (click.Context | None): The current context that arrived + at this value. + + Returns: + dict: The validated and converted dictionary. + + Raises: + click.BadParameter: If the validation is failed. + """ + if isinstance(cli_value, dict): + return cli_value + try: + keyvalue_pairs = cli_value.rstrip(";").split(";") + result_dict = {} + for pair in keyvalue_pairs: + key, values = [item.strip() for item in pair.split("=")] + converted_values = [] + for value in values.split(","): + value = value.strip() + if value.isdigit(): + value = int(value) + converted_values.append(value) + + if len(converted_values) == 1: + result_dict[key] = converted_values[0] + elif len(converted_values) > 1 and converted_values[-1] == "": + result_dict[key] = converted_values[:-1] + else: + result_dict[key] = converted_values + return result_dict + except ValueError: + self.fail( + "All key-value pairs must be separated by one semicolon. " + "Key and value must be separated by one equal sign. " + "List value items must be separated by one comma. " + f"Key-value: {pair}.", + param, + ctx, + ) + + class EnumChoice(click.Choice): """A version of click.Choice specialized for enum types.""" @@ -114,7 +190,7 @@ class OutputEnum(Enum): update_dict = PartialOption( "--update_dict", - type=dict, + type=DictParamType(), help="Values to update", ) diff --git a/src/lsst/cmservice/client.py b/src/lsst/cmservice/client.py index 31590a743..1dbe414c0 100644 --- a/src/lsst/cmservice/client.py +++ b/src/lsst/cmservice/client.py @@ -217,7 +217,7 @@ def get_job_errors( fullname=fullname, ) query = "get/job/errors" - results = self._client.get(f"{query}", params=params.dict()).json() + results = self._client.get(f"{query}", parans=params.dict()).json() try: return parse_obj_as(list[models.ErrorInstance], results) except ValidationError as msg: @@ -229,7 +229,7 @@ def update_status( ) -> StatusEnum: query = "update/status" params = models.UpdateStatusQuery(**kwargs) - results = self._client.post(f"{query}", params=params.dict()).json() + results = self._client.post(f"{query}", data=params.dict()).json() try: return parse_obj_as(StatusEnum, results) except ValidationError as msg: @@ -242,7 +242,7 @@ def update_collections( query = "update/collections" kwargs["node_type"] = kwargs["node_type"].value params = models.UpdateNodeQuery(**kwargs) - results = self._client.post(f"{query}", params=params.dict()).json() + results = self._client.post(f"{query}", data=params.dict()).json() try: return parse_obj_as(dict, results) except ValidationError as msg: @@ -255,7 +255,7 @@ def update_data_dict( query = "update/data_dict" kwargs["node_type"] = kwargs["node_type"].value params = models.UpdateNodeQuery(**kwargs) - results = self._client.post(f"{query}", params=params.dict()).json() + results = self._client.post(f"{query}", data=params.dict()).json() try: return parse_obj_as(dict, results) except ValidationError as msg: @@ -268,7 +268,8 @@ def update_child_config( query = "update/child_config" kwargs["node_type"] = kwargs["node_type"].value params = models.UpdateNodeQuery(**kwargs) - results = self._client.post(f"{query}", params=params.dict()).json() + print(params.dict()) + results = self._client.post(f"{query}", data=params.json()).json() try: return parse_obj_as(dict, results) except ValidationError as msg: @@ -280,7 +281,7 @@ def add_groups( ) -> list[models.Group]: query = "add/groups" params = models.AddGroups(**kwargs) - results = self._client.post(f"{query}", params=params.dict()).json() + results = self._client.post(f"{query}", data=params.dict()).json() try: return parse_obj_as(list[models.Group], results) except ValidationError as msg: @@ -292,7 +293,7 @@ def add_steps( ) -> list[models.Group]: query = "add/groups" params = models.AddSteps(**kwargs) - results = self._client.post(f"{query}", params=params.dict()).json() + results = self._client.post(f"{query}", data=params.dict()).json() try: return parse_obj_as(list[models.Group], results) except ValidationError as msg: @@ -304,7 +305,7 @@ def add_campaign( ) -> models.Campaign: query = "add/campaign" params = models.CampaignCreate(**kwargs) - results = self._client.post(f"{query}", params=params.dict()).json() + results = self._client.post(f"{query}", data=params.dict()).json() try: return parse_obj_as(models.Campaign, results) except ValidationError as msg: @@ -316,7 +317,7 @@ def load_specification( ) -> models.Specification: query = "load/specification" params = models.SpecificationLoad(**kwargs) - results = self._client.post(f"{query}", params=params.dict()).json() + results = self._client.post(f"{query}", data=params.dict()).json() try: return parse_obj_as(models.Specification, results) except ValidationError as msg: @@ -328,7 +329,7 @@ def load_campaign( ) -> models.Campaign: query = "load/campaign" params = models.LoadAndCreateCampaign(**kwargs) - results = self._client.post(f"{query}", params=params.dict()).json() + results = self._client.post(f"{query}", data=params.dict()).json() try: return parse_obj_as(models.Campaign, results) except ValidationError as msg: @@ -340,7 +341,7 @@ def load_error_types( ) -> list[models.ErrorType]: query = "load/error_types" params = models.YamlFileQuery(**kwargs) - results = self._client.post(f"{query}", params=params.dict()).json() + results = self._client.post(f"{query}", data=params.dict()).json() try: return parse_obj_as(list[models.ErrorType], results) except ValidationError as msg: diff --git a/src/lsst/cmservice/common/enums.py b/src/lsst/cmservice/common/enums.py index 922086c23..eece05f00 100644 --- a/src/lsst/cmservice/common/enums.py +++ b/src/lsst/cmservice/common/enums.py @@ -103,7 +103,7 @@ class StatusEnum(enum.Enum): accepted = 5 # Completed, reviewed and accepted - rescued = 6 # Rescueable and rescue created + rescued = 6 # Rescueable and rescued """ # note that ordering of these Enums matters within the diff --git a/src/lsst/cmservice/db/error_instance.py b/src/lsst/cmservice/db/error_instance.py index 081686c9f..e687a628d 100644 --- a/src/lsst/cmservice/db/error_instance.py +++ b/src/lsst/cmservice/db/error_instance.py @@ -5,7 +5,7 @@ from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.schema import ForeignKey -from ..common.enums import ErrorAction, ErrorFlavor, ErrorSource +from ..common.enums import ErrorSource from .base import Base from .job import Job from .row import RowMixin @@ -25,8 +25,6 @@ class ErrorInstance(Base, RowMixin): job_id: Mapped[int | None] = mapped_column(ForeignKey("job.id", ondelete="CASCADE"), index=True) script_id: Mapped[int | None] = mapped_column(ForeignKey("script.id", ondelete="CASCADE"), index=True) source: Mapped[ErrorSource] = mapped_column() - flavor: Mapped[ErrorFlavor] = mapped_column() - action: Mapped[ErrorAction] = mapped_column() diagnostic_message: Mapped[str] = mapped_column() data: Mapped[Optional[dict | list]] = mapped_column(type_=JSON) diff --git a/src/lsst/cmservice/db/node.py b/src/lsst/cmservice/db/node.py index d74c1bf21..7d6b14fad 100644 --- a/src/lsst/cmservice/db/node.py +++ b/src/lsst/cmservice/db/node.py @@ -231,7 +231,7 @@ async def update_child_config( self.child_config.update(**kwargs) else: self.child_config = kwargs.copy() - await session.refresh(self) + await session.refresh(self, attribute_names=["child_config"]) return self.child_config async def update_collections( diff --git a/src/lsst/cmservice/models/error_instance.py b/src/lsst/cmservice/models/error_instance.py index 72c4931af..0dfd99fe6 100644 --- a/src/lsst/cmservice/models/error_instance.py +++ b/src/lsst/cmservice/models/error_instance.py @@ -1,13 +1,11 @@ from pydantic import BaseModel -from ..common.enums import ErrorAction, ErrorFlavor, ErrorSource +from ..common.enums import ErrorSource class ErrorInstanceBase(BaseModel): error_type_id: int | None source: ErrorSource - flavor: ErrorFlavor - action: ErrorAction diagnostic_message: str data: dict diff --git a/src/lsst/cmservice/routers/updates.py b/src/lsst/cmservice/routers/updates.py index 09e88f7f1..ed1e26062 100644 --- a/src/lsst/cmservice/routers/updates.py +++ b/src/lsst/cmservice/routers/updates.py @@ -5,6 +5,7 @@ from lsst.cmservice.db import interface from .. import models +from ..common.enums import NodeTypeEnum router = APIRouter( prefix="/update", @@ -25,7 +26,7 @@ async def update_status( result = await interface.update_status( session, query.fullname, - query.node_type, + NodeTypeEnum(query.node_type), query.status, ) return result @@ -44,7 +45,7 @@ async def update_collections( result = await interface.update_collections( session, query.fullname, - query.node_type, + NodeTypeEnum(query.node_type), **query.update_dict, ) return result @@ -63,7 +64,7 @@ async def update_child_config( result = await interface.update_child_config( session, query.fullname, - query.node_type, + NodeTypeEnum(query.node_type), **query.update_dict, ) return result @@ -82,7 +83,7 @@ async def update_data_dict( result = await interface.update_data_dict( session, query.fullname, - query.node_type, + NodeTypeEnum(query.node_type), **query.update_dict, ) return result