From 9268ea0065f5eb9a5357b76d1323d1a68bf16535 Mon Sep 17 00:00:00 2001 From: Francis Charette Migneault Date: Wed, 27 Mar 2024 14:53:50 -0400 Subject: [PATCH] fix all unittests for pystac ml-aoi extension --- pystac_ml_aoi/extensions/ml_aoi.py | 90 +++++++++++++++++++++++------- tests/test_pystac_extension.py | 31 +++++++++- 2 files changed, 98 insertions(+), 23 deletions(-) diff --git a/pystac_ml_aoi/extensions/ml_aoi.py b/pystac_ml_aoi/extensions/ml_aoi.py index a0d89f5..90a69a3 100644 --- a/pystac_ml_aoi/extensions/ml_aoi.py +++ b/pystac_ml_aoi/extensions/ml_aoi.py @@ -66,7 +66,13 @@ ML_AOI_PROPERTY = f"{ML_AOI_SCHEMA_ID}_".replace("-", "_") -class ML_AOI_Split(StringEnum): +class ExtendedEnum(StringEnum): + @classmethod + def values(cls) -> List[str]: + return list(cls.__members__.values()) + + +class ML_AOI_Split(ExtendedEnum): TRAIN: Literal["train"] = "train" VALIDATE: Literal["validate"] = "validate" TEST: Literal["test"] = "test" @@ -82,7 +88,7 @@ class ML_AOI_Split(StringEnum): ] -class ML_AOI_Role(StringEnum): +class ML_AOI_Role(ExtendedEnum): LABEL: Literal["label"] = "label" FEATURE: Literal["feature"] = "feature" @@ -96,7 +102,7 @@ class ML_AOI_Role(StringEnum): ] -class ML_AOI_Resampling(StringEnum): +class ML_AOI_Resampling(ExtendedEnum): NEAR: Literal["near"] = "near" BILINEAR: Literal["bilinear"] = "bilinear" CUBIC: Literal["cubic"] = "cubic" @@ -175,7 +181,7 @@ class ML_AOI_ItemProperties(ML_AOI_BaseFields, validate_assignment=True): class ML_AOI_CollectionFields(ML_AOI_BaseFields, validate_assignment=True): """ML-AOI properties for STAC Collections.""" - split: Optional[ML_AOI_Split] # split is required since it is the only available field + split: Optional[List[ML_AOI_Split]] # split is required since it is the only available field class ML_AOI_AssetFields(ML_AOI_BaseFields, validate_assignment=True): @@ -209,11 +215,18 @@ class ML_AOI_Extension( abc.ABC, ): @abc.abstractmethod - def get_ml_aoi_property(self, prop_name: str, _ml_aoi_required: bool) -> Any: + def get_ml_aoi_property(self, prop_name: str, *, _ml_aoi_required: bool) -> Any: raise NotImplementedError @abc.abstractmethod - def set_ml_aoi_property(self, prop_name: str, value: Any, _ml_aoi_required: bool, pop_if_none: bool = True) -> None: + def set_ml_aoi_property( + self, + prop_name: str, + value: Any, + pop_if_none: bool = True, + *, + _ml_aoi_required: bool, + ) -> None: raise NotImplementedError def __getitem__(self, prop_name): @@ -231,7 +244,7 @@ def _is_ml_aoi_property(cls, prop_name: str): ) @classmethod - def _retrieve_ml_aoi_property(cls, prop_name: str, _ml_aoi_required: bool) -> Optional[FieldInfo]: + def _retrieve_ml_aoi_property(cls, prop_name: str, *, _ml_aoi_required: bool) -> Optional[FieldInfo]: if not _ml_aoi_required and not cls._is_ml_aoi_property(prop_name): return try: @@ -350,18 +363,19 @@ class ML_AOI_PropertiesExtension( ML_AOI_Extension[T], abc.ABC, ): - def get_ml_aoi_property(self, prop_name: str, _ml_aoi_required: bool = True) -> list[Any]: - self._retrieve_ml_aoi_property(prop_name, _ml_aoi_required) + def get_ml_aoi_property(self, prop_name: str, *, _ml_aoi_required: bool = True) -> list[Any]: + self._retrieve_ml_aoi_property(prop_name, _ml_aoi_required=_ml_aoi_required) return self.properties.get(prop_name) def set_ml_aoi_property( self, prop_name: str, value: Any, + *, _ml_aoi_required: bool = True, pop_if_none: bool = True, ) -> None: - field = self._retrieve_ml_aoi_property(prop_name, _ml_aoi_required) + field = self._retrieve_ml_aoi_property(prop_name, _ml_aoi_required=_ml_aoi_required) if field: # validation must be performed against the non-aliased field # then, apply the alias for the actual assignment of the property @@ -402,22 +416,36 @@ def __init__(self, item: pystac.Item): def get_assets( self, - roles: Optional[List[ML_AOI_Role]] = None, + role: Optional[Union[ML_AOI_Role, List[ML_AOI_Role]]] = None, reference_grid: Optional[bool] = None, resampling_method: Optional[str] = None, + media_type: Optional[Union[pystac.MediaType, str]] = None, + asset_role: Optional[str] = None, ) -> dict[str, pystac.Asset]: """Get the item's assets where ``ml-aoi`` fields are matched. + Args: + role: The ML-AOI role, or a list of roles to filter Assets. + Note, this should not be confused with Asset 'roles'. + reference_grid: Filter Assets that contain the specified ML-AOI reference grid value. + resampling_method: Filter Assets that contain the specified ML-AOI resampling method. + media_type: Filter Assets with the given media-type. + asset_role: Filter Assets which contains the specified role (not to be confused with the ML-AOI role). Returns: Dict[str, Asset]: A dictionary of assets that matched filters. """ - roles = roles or get_args(ML_AOI_Role) + # since this method could be used for assets that refer to other extensions as well, + # filters must not limit themselves to ML-AOI fields + # if values are 'None', we must consider them as 'ignore' instead of 'any of' allowed values return { key: asset - for key, asset in self.item.get_assets().items() - if any( - role in (asset.extra_fields.get(add_ml_aoi_prefix("role")) or []) - for role in roles + for key, asset in self.item.get_assets(media_type=media_type, role=asset_role).items() + if ( + not role or + any( + asset_role in (asset.extra_fields.get(add_ml_aoi_prefix("role")) or []) + for asset_role in ([role] if isinstance(role, ML_AOI_Role) else role) + ) ) and ( reference_grid is None or @@ -521,10 +549,12 @@ def __init__(self, collection: pystac.Collection): def get_ml_aoi_property( self, prop_name: str, + *, _ml_aoi_required: bool = True, ) -> AnySummary: - found = self._retrieve_ml_aoi_property(prop_name, _ml_aoi_required) - if found or _ml_aoi_required: + field = self._retrieve_ml_aoi_property(prop_name, _ml_aoi_required=_ml_aoi_required) + if field or _ml_aoi_required: + prop_name = field.alias or prop_name return self.summaries.get_list(prop_name) return object.__getattribute__(self, prop_name) @@ -532,6 +562,8 @@ def set_ml_aoi_property( self, prop_name: str, value: AnySummary, + pop_if_none: bool = False, + *, _ml_aoi_required: bool = True, ) -> None: # if _ml_aoi_required and not hasattr(value, "__iter__") or isinstance(value, str): @@ -540,11 +572,12 @@ def set_ml_aoi_property( # f"received '{value.__class__.__name__}' is invalid." # ) - field = self._retrieve_ml_aoi_property(prop_name, _ml_aoi_required) + field = self._retrieve_ml_aoi_property(prop_name, _ml_aoi_required=_ml_aoi_required) if field or _ml_aoi_required: # prop_name = field.alias or prop_name - for summary in value: - self._validate_ml_aoi_property(prop_name, summary) + if not isinstance(value, (list, pystac.RangeSummary, dict)): + value = [value] + self._validate_ml_aoi_property(prop_name, value) prop_name = field.alias or prop_name super()._set_summary(prop_name, value) else: @@ -571,10 +604,25 @@ def __init__(self, collection: pystac.Collection): ML_AOI_SummariesExtension.__init__(self, collection) self.collection = collection self.properties = collection.extra_fields + self.collection.set_self_href = ML_AOI_CollectionExtension.set_self_href # override hook def __repr__(self) -> str: return f"" + def set_self_href(self, href: Optional[str]) -> None: + """Sets the absolute HREF that is represented by the ``rel == 'self'`` :class:`~pystac.Link`. + + Adds the relevant ML-AOI role applicable for the Collection. + """ + pystac.Collection.set_self_href(self.collection, href) + ml_aoi_split = self.get_ml_aoi_property("split") + if not ml_aoi_split: + return + for link in self.collection.links: + if link.rel == pystac.RelType.SELF: + field_name = add_ml_aoi_prefix("split") + link.extra_fields[field_name] = ml_aoi_split[0] + class ML_AOI_ExtensionHooks(ExtensionHooks): schema_uri: str = SCHEMA_URI diff --git a/tests/test_pystac_extension.py b/tests/test_pystac_extension.py index 6dba8ee..4d82a09 100644 --- a/tests/test_pystac_extension.py +++ b/tests/test_pystac_extension.py @@ -177,7 +177,7 @@ def test_ml_aoi_pystac_collection_with_apply_method( assert ML_AOI_SCHEMA_URI in collection.stac_extensions collection.validate() ml_aoi_col_json = collection.to_dict() - assert ml_aoi_col_json["summaries"] == [{"ml-aoi:split": "train"}] + assert ml_aoi_col_json["summaries"] == {"ml-aoi:split": ["train"]} def test_ml_aoi_pystac_collection_with_field_property( @@ -193,7 +193,30 @@ def test_ml_aoi_pystac_collection_with_field_property( assert ML_AOI_SCHEMA_URI in collection.stac_extensions collection.validate() ml_aoi_col_json = collection.to_dict() - assert ml_aoi_col_json["summaries"] == [{"ml-aoi:split": "train"}] + assert ml_aoi_col_json["summaries"] == {"ml-aoi:split": ["train"]} + + +def test_ml_aoi_pystac_collection_self_link_role( + collection: pystac.Collection, + stac_validator: pystac.validation.stac_validator.JsonSchemaSTACValidator, +) -> None: + """ + Validate extending a STAC Collection with ML-AOI extension. + """ + assert ML_AOI_SCHEMA_URI not in collection.stac_extensions + ml_aoi_col = ML_AOI_Extension.ext(collection, add_if_missing=True) + ml_aoi_col.split = [ML_AOI_Split.TEST] + assert ML_AOI_SCHEMA_URI in collection.stac_extensions + ml_aoi_col.set_self_href("https://example.com/collections/test") + ml_aoi_col_json = collection.to_dict() + assert ml_aoi_col_json["summaries"] == {"ml-aoi:split": ["test"]} + ml_aoi_col_links = ml_aoi_col_json["links"] + assert any(link == { + "rel": "self", + "href": "https://example.com/collections/test", + "type": "application/json", + "ml-aoi:split": "test", + } for link in ml_aoi_col_links), ml_aoi_col_links def test_ml_aoi_pystac_item_with_apply_method( @@ -238,8 +261,12 @@ def test_ml_aoi_pystac_item_filter_assets( assert ML_AOI_SCHEMA_URI not in item.stac_extensions ml_aoi_item = cast(ML_AOI_ItemExtension, ML_AOI_Extension.ext(item, add_if_missing=True)) ml_aoi_item.split = ML_AOI_Split.TRAIN + assets = ml_aoi_item.get_assets() + assert len(assets) == 2 and list(assets) == ["label", "raster"] assets = ml_aoi_item.get_assets(reference_grid=True) assert len(assets) == 1 and "raster" in assets + assets = ml_aoi_item.get_assets(role=ML_AOI_Role.LABEL) + assert len(assets) == 1 and "label" in assets if __name__ == "__main__":