Skip to content

Commit

Permalink
fix all unittests for pystac ml-aoi extension
Browse files Browse the repository at this point in the history
  • Loading branch information
fmigneault committed Mar 27, 2024
1 parent ef2f130 commit 9268ea0
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 23 deletions.
90 changes: 69 additions & 21 deletions pystac_ml_aoi/extensions/ml_aoi.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,13 @@
ML_AOI_PROPERTY = f"{ML_AOI_SCHEMA_ID}_".replace("-", "_")


class ML_AOI_Split(StringEnum):
class ExtendedEnum(StringEnum):
@classmethod
def values(cls) -> List[str]:
return list(cls.__members__.values())


class ML_AOI_Split(ExtendedEnum):
TRAIN: Literal["train"] = "train"
VALIDATE: Literal["validate"] = "validate"
TEST: Literal["test"] = "test"
Expand All @@ -82,7 +88,7 @@ class ML_AOI_Split(StringEnum):
]


class ML_AOI_Role(StringEnum):
class ML_AOI_Role(ExtendedEnum):
LABEL: Literal["label"] = "label"
FEATURE: Literal["feature"] = "feature"

Expand All @@ -96,7 +102,7 @@ class ML_AOI_Role(StringEnum):
]


class ML_AOI_Resampling(StringEnum):
class ML_AOI_Resampling(ExtendedEnum):
NEAR: Literal["near"] = "near"
BILINEAR: Literal["bilinear"] = "bilinear"
CUBIC: Literal["cubic"] = "cubic"
Expand Down Expand Up @@ -175,7 +181,7 @@ class ML_AOI_ItemProperties(ML_AOI_BaseFields, validate_assignment=True):
class ML_AOI_CollectionFields(ML_AOI_BaseFields, validate_assignment=True):
"""ML-AOI properties for STAC Collections."""

split: Optional[ML_AOI_Split] # split is required since it is the only available field
split: Optional[List[ML_AOI_Split]] # split is required since it is the only available field


class ML_AOI_AssetFields(ML_AOI_BaseFields, validate_assignment=True):
Expand Down Expand Up @@ -209,11 +215,18 @@ class ML_AOI_Extension(
abc.ABC,
):
@abc.abstractmethod
def get_ml_aoi_property(self, prop_name: str, _ml_aoi_required: bool) -> Any:
def get_ml_aoi_property(self, prop_name: str, *, _ml_aoi_required: bool) -> Any:
raise NotImplementedError

@abc.abstractmethod
def set_ml_aoi_property(self, prop_name: str, value: Any, _ml_aoi_required: bool, pop_if_none: bool = True) -> None:
def set_ml_aoi_property(
self,
prop_name: str,
value: Any,
pop_if_none: bool = True,
*,
_ml_aoi_required: bool,
) -> None:
raise NotImplementedError

def __getitem__(self, prop_name):
Expand All @@ -231,7 +244,7 @@ def _is_ml_aoi_property(cls, prop_name: str):
)

@classmethod
def _retrieve_ml_aoi_property(cls, prop_name: str, _ml_aoi_required: bool) -> Optional[FieldInfo]:
def _retrieve_ml_aoi_property(cls, prop_name: str, *, _ml_aoi_required: bool) -> Optional[FieldInfo]:
if not _ml_aoi_required and not cls._is_ml_aoi_property(prop_name):
return
try:
Expand Down Expand Up @@ -350,18 +363,19 @@ class ML_AOI_PropertiesExtension(
ML_AOI_Extension[T],
abc.ABC,
):
def get_ml_aoi_property(self, prop_name: str, _ml_aoi_required: bool = True) -> list[Any]:
self._retrieve_ml_aoi_property(prop_name, _ml_aoi_required)
def get_ml_aoi_property(self, prop_name: str, *, _ml_aoi_required: bool = True) -> list[Any]:
self._retrieve_ml_aoi_property(prop_name, _ml_aoi_required=_ml_aoi_required)
return self.properties.get(prop_name)

def set_ml_aoi_property(
self,
prop_name: str,
value: Any,
*,
_ml_aoi_required: bool = True,
pop_if_none: bool = True,
) -> None:
field = self._retrieve_ml_aoi_property(prop_name, _ml_aoi_required)
field = self._retrieve_ml_aoi_property(prop_name, _ml_aoi_required=_ml_aoi_required)
if field:
# validation must be performed against the non-aliased field
# then, apply the alias for the actual assignment of the property
Expand Down Expand Up @@ -402,22 +416,36 @@ def __init__(self, item: pystac.Item):

def get_assets(
self,
roles: Optional[List[ML_AOI_Role]] = None,
role: Optional[Union[ML_AOI_Role, List[ML_AOI_Role]]] = None,
reference_grid: Optional[bool] = None,
resampling_method: Optional[str] = None,
media_type: Optional[Union[pystac.MediaType, str]] = None,
asset_role: Optional[str] = None,
) -> dict[str, pystac.Asset]:
"""Get the item's assets where ``ml-aoi`` fields are matched.
Args:
role: The ML-AOI role, or a list of roles to filter Assets.
Note, this should not be confused with Asset 'roles'.
reference_grid: Filter Assets that contain the specified ML-AOI reference grid value.
resampling_method: Filter Assets that contain the specified ML-AOI resampling method.
media_type: Filter Assets with the given media-type.
asset_role: Filter Assets which contains the specified role (not to be confused with the ML-AOI role).
Returns:
Dict[str, Asset]: A dictionary of assets that matched filters.
"""
roles = roles or get_args(ML_AOI_Role)
# since this method could be used for assets that refer to other extensions as well,
# filters must not limit themselves to ML-AOI fields
# if values are 'None', we must consider them as 'ignore' instead of 'any of' allowed values
return {
key: asset
for key, asset in self.item.get_assets().items()
if any(
role in (asset.extra_fields.get(add_ml_aoi_prefix("role")) or [])
for role in roles
for key, asset in self.item.get_assets(media_type=media_type, role=asset_role).items()
if (
not role or
any(
asset_role in (asset.extra_fields.get(add_ml_aoi_prefix("role")) or [])
for asset_role in ([role] if isinstance(role, ML_AOI_Role) else role)
)
)
and (
reference_grid is None or
Expand Down Expand Up @@ -521,17 +549,21 @@ def __init__(self, collection: pystac.Collection):
def get_ml_aoi_property(
self,
prop_name: str,
*,
_ml_aoi_required: bool = True,
) -> AnySummary:
found = self._retrieve_ml_aoi_property(prop_name, _ml_aoi_required)
if found or _ml_aoi_required:
field = self._retrieve_ml_aoi_property(prop_name, _ml_aoi_required=_ml_aoi_required)
if field or _ml_aoi_required:
prop_name = field.alias or prop_name
return self.summaries.get_list(prop_name)
return object.__getattribute__(self, prop_name)

def set_ml_aoi_property(
self,
prop_name: str,
value: AnySummary,
pop_if_none: bool = False,
*,
_ml_aoi_required: bool = True,
) -> None:
# if _ml_aoi_required and not hasattr(value, "__iter__") or isinstance(value, str):
Expand All @@ -540,11 +572,12 @@ def set_ml_aoi_property(
# f"received '{value.__class__.__name__}' is invalid."
# )

field = self._retrieve_ml_aoi_property(prop_name, _ml_aoi_required)
field = self._retrieve_ml_aoi_property(prop_name, _ml_aoi_required=_ml_aoi_required)
if field or _ml_aoi_required:
# prop_name = field.alias or prop_name
for summary in value:
self._validate_ml_aoi_property(prop_name, summary)
if not isinstance(value, (list, pystac.RangeSummary, dict)):
value = [value]
self._validate_ml_aoi_property(prop_name, value)
prop_name = field.alias or prop_name
super()._set_summary(prop_name, value)
else:
Expand All @@ -571,10 +604,25 @@ def __init__(self, collection: pystac.Collection):
ML_AOI_SummariesExtension.__init__(self, collection)
self.collection = collection
self.properties = collection.extra_fields
self.collection.set_self_href = ML_AOI_CollectionExtension.set_self_href # override hook

def __repr__(self) -> str:
return f"<ML_AOI_CollectionExtension Collection id={self.collection.id}>"

def set_self_href(self, href: Optional[str]) -> None:
"""Sets the absolute HREF that is represented by the ``rel == 'self'`` :class:`~pystac.Link`.
Adds the relevant ML-AOI role applicable for the Collection.
"""
pystac.Collection.set_self_href(self.collection, href)
ml_aoi_split = self.get_ml_aoi_property("split")
if not ml_aoi_split:
return
for link in self.collection.links:
if link.rel == pystac.RelType.SELF:
field_name = add_ml_aoi_prefix("split")
link.extra_fields[field_name] = ml_aoi_split[0]


class ML_AOI_ExtensionHooks(ExtensionHooks):
schema_uri: str = SCHEMA_URI
Expand Down
31 changes: 29 additions & 2 deletions tests/test_pystac_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def test_ml_aoi_pystac_collection_with_apply_method(
assert ML_AOI_SCHEMA_URI in collection.stac_extensions
collection.validate()
ml_aoi_col_json = collection.to_dict()
assert ml_aoi_col_json["summaries"] == [{"ml-aoi:split": "train"}]
assert ml_aoi_col_json["summaries"] == {"ml-aoi:split": ["train"]}


def test_ml_aoi_pystac_collection_with_field_property(
Expand All @@ -193,7 +193,30 @@ def test_ml_aoi_pystac_collection_with_field_property(
assert ML_AOI_SCHEMA_URI in collection.stac_extensions
collection.validate()
ml_aoi_col_json = collection.to_dict()
assert ml_aoi_col_json["summaries"] == [{"ml-aoi:split": "train"}]
assert ml_aoi_col_json["summaries"] == {"ml-aoi:split": ["train"]}


def test_ml_aoi_pystac_collection_self_link_role(
collection: pystac.Collection,
stac_validator: pystac.validation.stac_validator.JsonSchemaSTACValidator,
) -> None:
"""
Validate extending a STAC Collection with ML-AOI extension.
"""
assert ML_AOI_SCHEMA_URI not in collection.stac_extensions
ml_aoi_col = ML_AOI_Extension.ext(collection, add_if_missing=True)
ml_aoi_col.split = [ML_AOI_Split.TEST]
assert ML_AOI_SCHEMA_URI in collection.stac_extensions
ml_aoi_col.set_self_href("https://example.com/collections/test")
ml_aoi_col_json = collection.to_dict()
assert ml_aoi_col_json["summaries"] == {"ml-aoi:split": ["test"]}
ml_aoi_col_links = ml_aoi_col_json["links"]
assert any(link == {
"rel": "self",
"href": "https://example.com/collections/test",
"type": "application/json",
"ml-aoi:split": "test",
} for link in ml_aoi_col_links), ml_aoi_col_links


def test_ml_aoi_pystac_item_with_apply_method(
Expand Down Expand Up @@ -238,8 +261,12 @@ def test_ml_aoi_pystac_item_filter_assets(
assert ML_AOI_SCHEMA_URI not in item.stac_extensions
ml_aoi_item = cast(ML_AOI_ItemExtension, ML_AOI_Extension.ext(item, add_if_missing=True))
ml_aoi_item.split = ML_AOI_Split.TRAIN
assets = ml_aoi_item.get_assets()
assert len(assets) == 2 and list(assets) == ["label", "raster"]
assets = ml_aoi_item.get_assets(reference_grid=True)
assert len(assets) == 1 and "raster" in assets
assets = ml_aoi_item.get_assets(role=ML_AOI_Role.LABEL)
assert len(assets) == 1 and "label" in assets


if __name__ == "__main__":
Expand Down

0 comments on commit 9268ea0

Please sign in to comment.