diff --git a/json-schema/schema.json b/json-schema/schema.json index b5618e1..542fe50 100644 --- a/json-schema/schema.json +++ b/json-schema/schema.json @@ -2,7 +2,7 @@ "$schema": "http://json-schema.org/draft-07/schema#", "$id": "https://stac-extensions.github.io/ml-aoi/v0.1.0/schema.json#", "title": "ML AOI Extension", - "description": "ML AOI Extension for STAC Items.", + "description": "ML AOI Extension for STAC definitions.", "oneOf": [ { "$comment": "This is the schema for STAC Collections.", diff --git a/pystac_ml_aoi/extensions/ml_aoi.py b/pystac_ml_aoi/extensions/ml_aoi.py index 57a32c8..a0d89f5 100644 --- a/pystac_ml_aoi/extensions/ml_aoi.py +++ b/pystac_ml_aoi/extensions/ml_aoi.py @@ -61,7 +61,7 @@ ML_AOI_SCHEMA = json.load(schema_file) ML_AOI_SCHEMA_ID: SchemaName = get_args(SchemaName)[0] -ML_AOI_SCHEMA_URI: str = ML_AOI_SCHEMA["$id"] +ML_AOI_SCHEMA_URI: str = ML_AOI_SCHEMA["$id"].split("#")[0] ML_AOI_PREFIX = f"{ML_AOI_SCHEMA_ID}:" ML_AOI_PROPERTY = f"{ML_AOI_SCHEMA_ID}_".replace("-", "_") diff --git a/tests/test_pystac_extension.py b/tests/test_pystac_extension.py index 556ca89..6dba8ee 100644 --- a/tests/test_pystac_extension.py +++ b/tests/test_pystac_extension.py @@ -16,6 +16,8 @@ from pystac.extensions.label import LabelExtension, LabelTask, LabelType from pystac_ml_aoi.extensions.ml_aoi import ( ML_AOI_Extension, + ML_AOI_CollectionExtension, + ML_AOI_ItemExtension, ML_AOI_SCHEMA_PATH, ML_AOI_SCHEMA_URI, ML_AOI_Role, @@ -23,6 +25,21 @@ ) +EUROSAT_EXAMPLE_BASE_URL = "https://raw.githubusercontent.com/ai-extensions/stac-data-loader/0.5.0/data/EuroSAT" +EUROSAT_EXAMPLE_ASSET_ITEM_URL = ( + EUROSAT_EXAMPLE_BASE_URL + + "/stac/subset/train/item-42.json" +) +EUROSAT_EXAMPLE_ASSET_LABEL_URL = ( + EUROSAT_EXAMPLE_BASE_URL + + "/data/subset/ds/images/remote_sensing/otherDatasets/sentinel_2/label/Residential/Residential_1331.geojson" +) +EUROSAT_EXAMPLE_ASSET_RASTER_URL = ( + EUROSAT_EXAMPLE_BASE_URL + + "/data/subset/ds/images/remote_sensing/otherDatasets/sentinel_2/tif/Residential/Residential_1331.tif" +) + + @pytest.fixture(scope="session", name="stac_validator", autouse=True) def make_stac_ml_aoi_validator( request: pytest.FixtureRequest, @@ -38,8 +55,7 @@ def make_stac_ml_aoi_validator( validator = pystac.validation.RegisteredValidator.get_validator() validator = cast(pystac.validation.stac_validator.JsonSchemaSTACValidator, validator) validation_schema = json.loads(pystac.StacIO.default().read_text(ML_AOI_SCHEMA_PATH)) - validation_uri = ML_AOI_SCHEMA_URI.split("#")[0] - validator.schema_cache[validation_uri] = validation_schema + validator.schema_cache[ML_AOI_SCHEMA_URI] = validation_schema pystac.validation.RegisteredValidator.set_validator(validator) # apply globally to allow 'STACObject.validate()' return validator @@ -111,6 +127,23 @@ def make_base_stac_item() -> pystac.Item: ] } bbox = list(shapely.geometry.shape(geom).bounds) + asset_label = pystac.Asset( + href=EUROSAT_EXAMPLE_ASSET_LABEL_URL, + media_type=pystac.MediaType.GEOJSON, + roles=["data"], + extra_fields={ + "ml-aoi:role": ML_AOI_Role.LABEL, + } + ) + asset_raster = pystac.Asset( + href=EUROSAT_EXAMPLE_ASSET_RASTER_URL, + media_type=pystac.MediaType.GEOTIFF, + roles=["data"], + extra_fields={ + "ml-aoi:reference-grid": True, + "ml-aoi:role": ML_AOI_Role.FEATURE, + } + ) item = pystac.Item( id="EuroSAT-subset-train-sample-42-class-Residential", geometry=geom, @@ -118,13 +151,15 @@ def make_base_stac_item() -> pystac.Item: start_datetime=dt_parse("2015-06-27T10:25:31.456Z"), end_datetime=dt_parse("2017-06-14T00:00:00Z"), datetime=None, - properties={} + properties={}, + assets={"label": asset_label, "raster": asset_raster}, ) label_item = LabelExtension.ext(item, add_if_missing=True) label_item.apply( label_description="ml-aoi-test", - label_type=LabelType.RASTER, - label_tasks=[LabelTask.CLASSIFICATION], + label_type=LabelType.VECTOR, + label_tasks=[LabelTask.CLASSIFICATION, LabelTask.SEGMENTATION], + label_properties=["class"], ) return item @@ -193,5 +228,19 @@ def test_ml_aoi_pystac_item_with_field_property( assert ml_aoi_item_json["properties"]["ml-aoi:split"] == "train" +def test_ml_aoi_pystac_item_filter_assets( + item: pystac.Item, + stac_validator: pystac.validation.stac_validator.JsonSchemaSTACValidator, +) -> None: + """ + Validate extending a STAC Collection with ML-AOI extension. + """ + assert ML_AOI_SCHEMA_URI not in item.stac_extensions + ml_aoi_item = cast(ML_AOI_ItemExtension, ML_AOI_Extension.ext(item, add_if_missing=True)) + ml_aoi_item.split = ML_AOI_Split.TRAIN + assets = ml_aoi_item.get_assets(reference_grid=True) + assert len(assets) == 1 and "raster" in assets + + if __name__ == "__main__": unittest.main()